diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py new file mode 100644 index 000000000000..0d2ca310bce8 --- /dev/null +++ b/tests/lax_autodiff_test.py @@ -0,0 +1,1044 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +from functools import partial +import itertools +from unittest import SkipTest + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as onp + +import jax +from jax import api +from jax import core +from jax import dtypes +from jax import lax +from jax import test_util as jtu +from jax.test_util import check_grads + +from tests.lax_test import (CombosWithReplacement, compatible_shapes, + complex_dtypes, float_dtypes, inexact_dtypes, + num_float_bits) + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + + +GradTestSpec = collections.namedtuple( + "GradTestSpec", + ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) +def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): + return GradTestSpec( + op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + +grad_float_dtypes = list(jtu.supported_dtypes().intersection( + {onp.float32, onp.float64})) +grad_complex_dtypes = list(jtu.supported_dtypes().intersection( + {onp.complex64, onp.complex128})) +grad_inexact_dtypes = grad_float_dtypes + grad_complex_dtypes + +LAX_GRAD_OPS = [ + grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.floor, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), + dtypes=grad_float_dtypes), + grad_test_spec(lax.ceil, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), + dtypes=grad_float_dtypes), + grad_test_spec(lax.round, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), + dtypes=grad_float_dtypes), + + grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.log1p, nargs=1, order=2, rng_factory=jtu.rand_positive, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.sinh, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes + [onp.complex64], tol=1e-5), + grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes, tol=1e-5), + grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes, tol=1e-5), + grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes, tol={onp.float32: 5e-1}), + grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.tan, nargs=1, order=2, + 
rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), + dtypes=grad_inexact_dtypes, tol=1e-3), + grad_test_spec(lax.asin, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), + dtypes=grad_float_dtypes, tol=1e-3), + grad_test_spec(lax.acos, nargs=1, order=2, + rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), + dtypes=grad_float_dtypes, tol=2e-2), + # TODO(proteneer): atan2 input is already a representation of a + # complex number. Need to think harder about what this even means + # if each input itself is a complex number. + grad_test_spec(lax.atan2, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + + grad_test_spec(lax.erf, nargs=1, order=2, rng_factory=jtu.rand_small, + dtypes=grad_float_dtypes), + grad_test_spec(lax.erfc, nargs=1, order=2, rng_factory=jtu.rand_small, + dtypes=grad_float_dtypes), + grad_test_spec(lax.erf_inv, nargs=1, order=2, rng_factory=jtu.rand_small, + dtypes=grad_float_dtypes), + # grad_test_spec(lax.lgamma, nargs=1, order=2, rng_factory=jtu.rand_small, + # dtypes=grad_float_dtypes), # TODO(mattjj): enable + grad_test_spec(lax.bessel_i0e, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + grad_test_spec(lax.bessel_i1e, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + + grad_test_spec(lax.real, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_complex_dtypes), + grad_test_spec(lax.imag, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_complex_dtypes), + grad_test_spec(lax.complex, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + grad_test_spec(lax.conj, nargs=1, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.abs, nargs=1, order=2, rng_factory=jtu.rand_positive, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive, + dtypes=grad_inexact_dtypes, tol={onp.float32: 3e-1}), + + grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.sub, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.mul, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_inexact_dtypes), + grad_test_spec(lax.div, nargs=2, order=1, rng_factory=jtu.rand_not_small, + dtypes=grad_inexact_dtypes), + + grad_test_spec(lax.max, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + grad_test_spec(lax.min, nargs=2, order=2, rng_factory=jtu.rand_default, + dtypes=grad_float_dtypes), + # TODO(mattjj): make some-equal checks more robust, enable second-order + # grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal, + # dtypes=grad_float_dtypes, name="MaxSomeEqual"), + # grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal, + # dtypes=grad_float_dtypes, name="MinSomeEqual"), +] + +GradSpecialValuesTestSpec = collections.namedtuple( + "GradSpecialValuesTestSpec", ["op", "values", "tol"]) +def grad_special_values_test_spec(op, values, tol=None): + return GradSpecialValuesTestSpec(op, values, tol) + +LAX_GRAD_SPECIAL_VALUE_TESTS = [ + grad_special_values_test_spec( + lax.sinh, [0.], + tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None), + grad_special_values_test_spec( + lax.cosh, [0.], + tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None), + grad_special_values_test_spec(lax.tanh, [0., 1000.]), + 
grad_special_values_test_spec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]), + grad_special_values_test_spec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]), + grad_special_values_test_spec(lax.tan, [0.]), + grad_special_values_test_spec(lax.asin, [0.]), + grad_special_values_test_spec(lax.acos, [0.]), + grad_special_values_test_spec(lax.atan, [0., 1000.]), + grad_special_values_test_spec(lax.erf, [0., 10.]), + grad_special_values_test_spec(lax.erfc, [0., 10.]), +] + + +def check_grads_bilinear(f, args, order, + modes=["fwd", "rev"], atol=None, rtol=None): + # Can use large eps to make up for numerical inaccuracies since the op is + # bilinear (relying on the fact that we only check one arg at a time) + lhs, rhs = args + check_grads(lambda lhs: f(lhs, rhs), (lhs,), order, + modes=modes, atol=atol, rtol=rtol, eps=1.) + check_grads(lambda rhs: f(lhs, rhs), (rhs,), order, + modes=modes, atol=atol, rtol=rtol, eps=1.) + +class LaxAutodiffTest(jtu.JaxTestCase): + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix( + rec.name, shapes, itertools.repeat(dtype)), + "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, + "order": rec.order, "tol": rec.tol} + for shape_group in compatible_shapes + for shapes in CombosWithReplacement(shape_group, rec.nargs) + for dtype in rec.dtypes) + for rec in LAX_GRAD_OPS)) + def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): + rng = rng_factory(self.rng()) + if jtu.device_under_test() == "tpu" and op is lax.pow: + raise SkipTest("pow grad imprecise on tpu") + tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol + args = tuple(rng(shape, dtype) for shape in shapes) + check_grads(op, args, order, ["fwd", "rev"], tol, tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), + "op": rec.op, "special_value": special_value, "tol": rec.tol} + for special_value in rec.values) + for rec in LAX_GRAD_SPECIAL_VALUE_TESTS)) + def testOpGradSpecialValue(self, op, special_value, tol): + check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_from_dtype={}_to_dtype={}".format( + jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)), + "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory} + for from_dtype, to_dtype in itertools.product( + float_dtypes + complex_dtypes, repeat=2) + for rng_factory in [jtu.rand_default])) + def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): + rng = rng_factory(self.rng()) + tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), + jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) + args = (rng((2, 3), from_dtype),) + convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) + convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)( + convert_element_type) + check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) 
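+  # clamp passes the gradient through only for the argument selected at each
+  # point; the three checks below cycle the operand through all three argument
+  # slots of lax.clamp.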
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype, + "rng_factory": rng_factory} + for shape in [(), (2, 3)] + for dtype in grad_float_dtypes + for rng_factory in [jtu.rand_default])) + def testClampGrad(self, shape, dtype, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + low = operand - dtype(10) + high = operand + dtype(10) + # Avoids points near the boundary where the gradient may be inaccurate. + check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2) + check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2) + check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( + dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name, + num_arrs), + "dim": dim, "base_shape": base_shape, "dtype": dtype, + "num_arrs": num_arrs, "rng_factory": rng_factory} + for num_arrs in [3] + for dtype in float_dtypes + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for dim in range(len(base_shape)) + for rng_factory in [jtu.rand_default])) + def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory): + rng = rng_factory(self.rng()) + shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] + operands = tuple(rng(shape, dtype) for shape in shapes) + concatenate = lambda *args: lax.concatenate(args, dim) + check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + strides, padding), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "strides": strides, "padding": padding, "rng_factory": rng_factory,} + for lhs_shape, rhs_shape, all_strides in itertools.chain( + [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) + for b, i, j in itertools.product([2, 3], repeat=3)], + [((4, 2, 1), (3, 2, 1), [(1,)])]) + for strides in all_strides + for dtype in float_dtypes + for padding in ["VALID", "SAME"] + for rng_factory in [jtu.rand_small])) + def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory): + rng = rng_factory(self.rng()) + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + conv = partial(lax.conv, window_strides=strides, padding=padding, + precision=lax.Precision.HIGHEST) + check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], + atol=1e-2, rtol=1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" + "rhs_dilation={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + strides, padding, lhs_dil, rhs_dil), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "strides": strides, "padding": padding, "lhs_dil": lhs_dil, + "rhs_dil": rhs_dil, "rng_factory": rng_factory} + for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in + itertools.chain( + [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], + [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], + [(1, 1), (2, 1)], [(1, 1)]) + for b, 
i, j in itertools.product([2, 3], repeat=3)], + [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)], + [(1,), (2,)], [(1,), (2,)])]) + for strides in all_strides + for rhs_dil in rhs_dils + for lhs_dil in lhs_dils + for dtype in float_dtypes + for padding in all_pads + for rng_factory in [jtu.rand_small])) + def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, + padding, lhs_dil, rhs_dil, rng_factory): + rng = rng_factory(self.rng()) + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + conv = partial(lax.conv_with_general_padding, window_strides=strides, + padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, + precision=lax.Precision.HIGHEST) + check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], + atol=1e-2, rtol=1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" + "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), + feature_group_count, batch_group_count), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "strides": strides, "padding": padding, "lhs_dil": lhs_dil, + "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, + "perms": perms, "feature_group_count": feature_group_count, + "batch_group_count": batch_group_count} + for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) + for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ + ([(b * batch_group_count, i * feature_group_count, 6, 7), + (b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape + (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape + [(1, 1), (1, 2), (2, 1)], # strides + [(1, 1), (2, 1)], # lhs_dils + [(1, 1), (2, 2)]) # rhs_dils + for b, i, j in itertools.product([1, 2], repeat=3)] + for lhs_shape in lhs_shapes + for strides in all_strides + for rhs_dil in rhs_dils + for lhs_dil in lhs_dils + for dtype in grad_float_dtypes + for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + + ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) + for dim_nums, perms in [ + (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), + (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] + for rng_factory in [jtu.rand_default] + )) + def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, + padding, lhs_dil, rhs_dil, dimension_numbers, + perms, feature_group_count, batch_group_count, + rng_factory): + if dtype == onp.float16: + raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve + + rng = rng_factory(self.rng()) + tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4} + + # permute shapes to match dim_spec, scale by feature_group_count + lhs_perm, rhs_perm = perms + lhs_shape = list(onp.take(lhs_shape, lhs_perm)) + rhs_shape = list(onp.take(rhs_shape, rhs_perm)) + + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + conv = partial(lax.conv_general_dilated, window_strides=strides, + padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count, + precision=lax.Precision.HIGHEST) + check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], + 
atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( + jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype)), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "rng_factory": jtu.rand_default} + for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] + for dtype in float_dtypes)) + def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): + rng = rng_factory(self.rng()) + tol = {onp.float16: 1e-1, onp.float32: 1e-4} + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + dot = partial(lax.dot, precision=lax.Precision.HIGHEST) + check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"], + atol=tol, rtol=tol) + # check that precision config is preserved + result, pullback = api.vjp(dot, lhs, rhs) + gresult = lax.zeros_like_array(result) + s = str(api.make_jaxpr(pullback)(gresult)) + assert "precision=HIGHEST" in s + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + dimension_numbers), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small} + for lhs_shape, rhs_shape, dimension_numbers in [ + ((3, 2), (2, 4), (([1], [0]), ([], []))), + ((3, 5), (2, 5), (([1], [1]), ([], []))), + ((5, 3), (5, 2), (([0], [0]), ([], []))), + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), + ] + for dtype in float_dtypes)) + def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, + dimension_numbers, rng_factory): + rng = rng_factory(self.rng()) + lhs = rng(lhs_shape, dtype) + rhs = rng(rhs_shape, dtype) + dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, + precision=lax.Precision.HIGHEST) + check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) + # check that precision config is preserved + result, pullback = api.vjp(dot_general, lhs, rhs) + gresult = lax.zeros_like_array(result) + s = str(api.make_jaxpr(pullback)(gresult)) + assert "precision=HIGHEST" in s + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( + shape, onp.dtype(dtype).name, broadcast_sizes), + "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, + "rng_factory": rng_factory} + for shape in [(), (2, 3)] + for dtype in float_dtypes + for broadcast_sizes in [(), (2,), (1, 2)] + for rng_factory in [jtu.rand_default])) + def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): + rng = rng_factory(self.rng()) + args = (rng(shape, dtype),) + broadcast = lambda x: lax.broadcast(x, broadcast_sizes) + check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) 
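+  # The transpose of broadcast_in_dim sums over the broadcasted axes, so these
+  # checks exercise both the expansion and the reduction code paths.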
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( + jtu.format_shape_dtype_string(inshape, dtype), + outshape, broadcast_dimensions), + "inshape": inshape, "dtype": dtype, "outshape": outshape, + "dimensions": broadcast_dimensions, "rng_factory": rng_factory} + for inshape, outshape, broadcast_dimensions in [ + ([2], [2, 2], [0]), + ([2], [2, 2], [1]), + ([2], [2, 3], [0]), + ([], [2, 3], []), + ] + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(inshape, dtype) + broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) + check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}_perm={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + permutation), + "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, + "rng_factory": rng_factory, "permutation": permutation} + for dtype in float_dtypes + for arg_shape, out_shape, permutation in [ + [(3, 4), (12,), None], + [(2, 1, 4), (8,), None], + [(2, 2, 4), (2, 8), None], + [(3, 4), (12,), (0, 1)], + [(3, 4), (12,), (1, 0)], + [(2, 1, 4), (8,), (0, 2, 1)], + [(2, 1, 4), (8,), (2, 0, 1)], + [(2, 2, 4), (2, 8), (0, 2, 1)], + [(2, 2, 4), (2, 8), (2, 0, 1)], + ] + for rng_factory in [jtu.rand_default])) + def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(arg_shape, dtype) + reshape = lambda x: lax.reshape(x, out_shape, permutation) + check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_pads={}" + .format(jtu.format_shape_dtype_string(shape, dtype), pads), + "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small} + for shape in [(2, 3)] + for dtype in float_dtypes + for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) + def testPadGrad(self, shape, dtype, pads, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) + check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) + + operand = rng(shape, dtype) + padding_value = onp.array(0., dtype) + pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) + check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) 
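+  # lax.rev is its own transpose: reversing the cotangent along the same
+  # dimensions gives the gradient.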
+ + def testReverseGrad(self): + rev = lambda operand: lax.rev(operand, dimensions) + + dimensions = [0] + check_grads(rev, (onp.array([3., 2., 1.]),), 2) + + dimensions = [0, 1] + check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2, + rtol={onp.float32: 3e-3}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_predshape={}_argshapes={}".format( + jtu.format_shape_dtype_string(pred_shape, onp.bool_), + jtu.format_shape_dtype_string(arg_shape, dtype)), + "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype, + "rng_factory": rng_factory} + for arg_shape in [(), (3,), (2, 3)] + for pred_shape in ([(), arg_shape] if arg_shape else [()]) + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): + rng = rng_factory(self.rng()) + pred = rng(pred_shape, onp.bool_) + on_true = rng(arg_shape, dtype) + on_false = rng(arg_shape, dtype) + select = lambda on_true, on_false: lax.select(pred, on_true, on_false) + check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_shape={}_start_indices={}_limit_indices={}_strides={}".format( + jtu.format_shape_dtype_string(shape, dtype), + start_indices, limit_indices, strides), + "shape": shape, "dtype": dtype, "starts": start_indices, + "limits": limit_indices, "strides": strides, "rng_factory": rng_factory} + for shape, start_indices, limit_indices, strides in [ + [(3,), (1,), (2,), None], + [(7,), (4,), (7,), None], + [(5,), (1,), (5,), (2,)], + [(8,), (1,), (6,), (2,)], + [(5, 3), (1, 1), (3, 2), None], + [(5, 3), (1, 1), (3, 1), None], + [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], + [(5, 3), (1, 1), (2, 1), (1, 1)], + [(5, 3), (1, 1), (5, 3), (2, 1)], + ] + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + slice = lambda x: lax.slice(x, starts, limits, strides) + check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( + jtu.format_shape_dtype_string(shape, dtype), + start_indices, size_indices), + "shape": shape, "dtype": dtype, "start_indices": start_indices, + "size_indices": size_indices, "rng_factory": rng_factory} + for shape, start_indices, size_indices in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 0, 1)], + ] + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, + rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) + check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) 
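+  # dynamic_update_slice is checked jointly and with each argument held fixed,
+  # covering both the operand and the update gradient paths.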
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype), + start_indices, update_shape), + "shape": shape, "dtype": dtype, "start_indices": start_indices, + "update_shape": update_shape, "rng_factory": rng_factory} + for shape, start_indices, update_shape in [ + [(3,), (1,), (1,)], + [(5, 3), (1, 1), (3, 1)], + [(7, 5, 3), (4, 1, 0), (2, 0, 1)], + ] + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, + update_shape, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + update = rng(update_shape, dtype) + start_indices = onp.array(start_indices) + + dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices) + check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.) + + dus = lambda x: lax.dynamic_update_slice(x, update, start_indices) + check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.) + + dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices) + check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_perm={}".format( + jtu.format_shape_dtype_string(shape, dtype), perm), + "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory} + for shape, perm in [ + [(3, 4), (1, 0)], + [(3, 4), (0, 1)], + [(3, 4, 5), (2, 1, 0)], + [(3, 4, 5), (1, 0, 2)], + ] + for dtype in float_dtypes + for rng_factory in [jtu.rand_default])) + def testTransposeGrad(self, shape, dtype, perm, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + transpose = lambda x: lax.transpose(x, perm) + check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) 
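+  # Gradients of max/min reductions are ambiguous when elements are tied, so
+  # those cases draw operands from rand_unique_int below.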
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_inshape={}_reducedims={}" + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims), + "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, + "dims": dims, "rng_factory": rng_factory} + for init_val, op, dtypes, rng_factory in [ + (0, lax.add, inexact_dtypes, jtu.rand_default), + (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int), + (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int), + (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)), + ] + for dtype in dtypes + for shape, dims in [ + [(3, 4, 5), ()], + [(3, 4, 5), (0,)], + [(3, 4, 5), (1, 2)], + [(3, 4, 5), (0, 2)], + [(3, 4, 5), (0, 1, 2)], + [(3, 1), (1,)], + ])) + def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): + rng = rng_factory(self.rng()) + if jtu.device_under_test() == "tpu" and op is lax.mul: + raise SkipTest("unimplemented case") + tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1, + onp.float64: 1e-3, onp.complex64: 1e-1} + operand = rng(shape, dtype) + init_val = onp.asarray(init_val, dtype=dtype) + reduce = lambda operand: lax.reduce(operand, init_val, op, dims) + eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else + 1e-1 if dtype == dtypes.bfloat16 else + 1e-2 if dtypes.finfo(dtype).bits == 32 else None) + check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_dtype={}_padding={}" + .format(op.__name__, onp.dtype(dtype).name, padding), + "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, + "rng_factory": rng_factory} + for init_val, op, dtypes, rng_factory in [ + (0, lax.add, grad_float_dtypes, jtu.rand_small), + (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), + (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), + ] + for dtype in dtypes + for padding in ["VALID", "SAME"])) + @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.ignore_warning(category=UserWarning, + message="Using reduced precision for gradient.*") + def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): + rng = rng_factory(self.rng()) + init_val = onp.asarray(init_val, dtype=dtype) + + # We need this conditional and the corresponding loop logic to be in the + # test method, rather than at the parameterized test level, because it + # depends on FLAGS for the device under test. + # TODO(b/31565929): enable when fixed. + if jtu.device_under_test() == "tpu" and op is not lax.add: + all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))] + + # TODO(b/73062247): need variadic reduce-window for better precision. + gradient_order = 1 + else: + all_configs = itertools.chain( + itertools.product( + [(4, 6)], # shapes + [(2, 1), (1, 2)], # window_dimensions + [(1, 1), (2, 1), (1, 2)] # strides + ), + itertools.product( + [(3, 2, 4, 6)], # shapes + [(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions + [(1, 2, 2, 1), (1, 1, 1, 1)]), # strides + ) + gradient_order = 3 + + def fun(operand): + return lax.reduce_window(operand, init_val, op, dims, strides, padding) + + for shape, dims, strides in all_configs: + operand = rng(shape, dtype) + if op is lax.add: + eps = 1. 
+ tol = None + else: + # this test can fail if there are duplicates in operand + self.assertEqual(onp.unique(operand).size, operand.size, + msg="test requires operand elements to be unique.") + eps = 1e-2 + tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2} + check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, + eps) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_shape={}_axis={}" + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis), + "op": op, "shape": shape, "dtype": dtype, + "axis": axis, "rng_factory": rng_factory} + for op, types in [ + (lax.cumsum, [onp.float32, onp.float64]), + (lax.cumprod, [onp.float32, onp.float64]), + ] + for dtype in types + for shape in [[10], [3, 4, 5]] + for axis in range(len(shape)) + for rng_factory in [ + jtu.rand_default if dtypes.issubdtype(dtype, onp.integer) + else jtu.rand_small])) + def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory): + rng = rng_factory(self.rng()) + check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2) + + + # TODO(b/205052657): enable more tests when supported + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis} + for dtype in [onp.float32] + for shape in [(5,), (5, 7)] + for axis in [len(shape) - 1] + for rng_factory in [jtu.rand_default])) + def testSortGrad(self, shape, dtype, axis, rng_factory): + rng = rng_factory(self.rng()) + operand = rng(shape, dtype) + sort = lambda x: lax.sort(x, dimension=axis) + check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) + + # TODO(b/205052657): enable more tests when supported + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_keyshape={}_valshape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, key_dtype), + jtu.format_shape_dtype_string(shape, val_dtype), + axis), + "rng_factory": rng_factory, "shape": shape, + "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis} + for key_dtype in [onp.float32] + for val_dtype in [onp.float32] + for shape in [(3,), (5, 3)] + for axis in [len(shape) - 1] + for rng_factory in [jtu.rand_default])) + def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, rng_factory): + rng = rng_factory(self.rng()) + # This test relies on the property that wherever keys are tied, values are + # too, since we don't guarantee the same ordering of values with equal keys. + # To avoid that case, we generate unique keys (globally in the key array). 
+ def args_maker(): + flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype) + keys = self.rng().permutation(flat_keys).reshape(shape) + values = rng(shape, val_dtype) + return keys, values + keys, values = args_maker() + + fun = lambda keys, values: lax.sort_key_val(keys, values, axis) + check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k} + for dtype in [onp.float32,] + for shape in [(4,), (5, 5), (2, 1, 4)] + for k in [1, 3] + for rng_factory in [jtu.rand_default])) + def testTopKGrad(self, shape, dtype, k, rng_factory): + flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype) + values = self.rng().permutation(flat_values).reshape(shape) + fun = lambda vs: lax.top_k(vs, k=k)[0] + check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), idxs, axes), + "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, + "rng_factory": rng_factory} + for dtype in float_dtypes + for shape, idxs, axes in [ + [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)], + [(3, 4, 5), (onp.array([-1, -2]),), (0,)], + [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)], + [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)], + ] + for rng_factory in [jtu.rand_default])) + def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory): + rng = rng_factory(self.rng()) + src = rng(shape, dtype) + index_take = lambda src: lax.index_take(src, idxs, axes) + check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( + jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, + slice_sizes), + "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, + "slice_sizes": slice_sizes, "rng_factory": rng_factory, + "rng_idx_factory": rng_idx_factory} + for dtype in grad_float_dtypes + for shape, idxs, dnums, slice_sizes, max_idx in [ + ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), + (1,), 5), + ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), + (2,), 9), + ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), + (1, 3), 3), + ] + for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] + for rng_factory in [jtu.rand_default])) + def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory, + rng_idx_factory): + rng = rng_factory(self.rng()) + rng_idx = rng_idx_factory(self.rng()) + idxs = rng_idx(idxs.shape, idxs.dtype) + gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums, + slice_sizes=slice_sizes) + x = rng(shape, dtype) + check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
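+  # scatter_add is linear in both the operand and the updates, so a relatively
+  # large finite-difference eps (1.) is safe here.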
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + idxs, update_shape, dnums), + "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, + "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, + "rng_idx_factory": rng_idx_factory} + for dtype in grad_float_dtypes + for arg_shape, idxs, update_shape, dnums, max_idx in [ + ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), 4), + ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,)), 9), + ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), 3), + ] + for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] + for rng_factory in [jtu.rand_default])) + def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums, + rng_factory, rng_idx_factory): + rng = rng_factory(self.rng()) + rng_idx = rng_idx_factory(self.rng()) + idxs = rng_idx(idxs.shape, idxs.dtype) + scatter_add = lambda x, y: lax.scatter_add(x, idxs, y, + dimension_numbers=dnums) + x = rng(arg_shape, dtype) + y = rng(update_shape, dtype) + check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + idxs, update_shape, dnums), + "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, + "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, + "rng_idx_factory": rng_idx_factory} + for dtype in grad_float_dtypes + for arg_shape, idxs, update_shape, dnums, max_idx in [ + ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), 4), + ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,)), 9), + ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), 3), + ] + # Scatters with conflicting indices are not deterministic on GPU, so we + # use indices that do not collide. + for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)] + for rng_factory in [jtu.rand_default])) + def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, + rng_factory, rng_idx_factory): + rng = rng_factory(self.rng()) + rng_idx = rng_idx_factory(self.rng()) + idxs = rng_idx(idxs.shape, idxs.dtype) + scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums) + x = rng(arg_shape, dtype) + y = rng(update_shape, dtype) + check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) + + def testScatterGradSymbolicZeroUpdate(self): + # https://github.com/google/jax/issues/1901 + def f(x): + n = x.shape[0] + y = onp.arange(n, dtype=x.dtype) + return jax.ops.index_update(x, onp.diag_indices(n), y) + rng = jtu.rand_default(self.rng()) + check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, + 1.) 
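+  # scatter_max / scatter_min route the cotangent to whichever of the operand
+  # or update wins the comparison at each scattered position.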
+ + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + idxs, update_shape, dnums), + "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, + "update_shape": update_shape, "dnums": dnums, + "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} + for dtype in grad_float_dtypes + for arg_shape, idxs, update_shape, dnums in [ + ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ] + for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] + for rng_factory in [jtu.rand_default])) + def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums, + rng_factory, rng_idx_factory): + rng = rng_factory(self.rng()) + rng_idx = rng_idx_factory(self.rng()) + idxs = rng_idx(idxs.shape, idxs.dtype) + scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums) + x = rng(arg_shape, dtype) + y = rng(update_shape, dtype) + check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + idxs, update_shape, dnums), + "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, + "update_shape": update_shape, "dnums": dnums, + "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} + for dtype in grad_float_dtypes + for arg_shape, idxs, update_shape, dnums in [ + ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ] + for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] + for rng_factory in [jtu.rand_default])) + def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums, + rng_factory, rng_idx_factory): + rng = rng_factory(self.rng()) + rng_idx = rng_idx_factory(self.rng()) + idxs = rng_idx(idxs.shape, idxs.dtype) + scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums) + x = rng(arg_shape, dtype) + y = rng(update_shape, dtype) + check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) + + def testStopGradient(self): + def f(x): + return lax.sin(x) * lax.cos(lax.stop_gradient(x)) + + def f2(x, y): + return lax.sin(x) * lax.cos(y) + + x = 3.14 + ans = api.grad(f)(x) + expected = api.grad(f2)(x, x) + self.assertAllClose(ans, expected) + + ans = api.grad(api.grad(f))(x) + expected = api.grad(api.grad(f2))(x, x) + self.assertAllClose(ans, expected) + + ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.) 
+ expected = onp.array(0.0) + self.assertAllClose(ans, expected, check_dtypes=False) + + with core.skipping_checks(): + with self.assertRaises(TypeError): + lax.stop_gradient(lambda x: x) + + # TODO(mattjj): make this a more systematic test + def testRemainder(self): + rng = onp.random.RandomState(0) + x = rng.uniform(-0.9, 9, size=(3, 4)) + y = rng.uniform(0.7, 1.9, size=(3, 1)) + assert not set(onp.unique(x)) & set(onp.unique(y)) + tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 + check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) + + rng = onp.random.RandomState(0) + x = rng.uniform(-0.9, 9, size=(1, 4)) + y = rng.uniform(0.7, 1.9, size=(3, 4)) + assert not set(onp.unique(x)) & set(onp.unique(y)) + tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 + check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) + + def testHigherOrderGradientOfReciprocal(self): + # Regression test for https://github.com/google/jax/issues/3136 + def inv(x): + # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) + return 1 / x + grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv)))))) + self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.))) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/lax_test.py b/tests/lax_test.py index e60fd8f4bcdb..e13081041f2a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -16,7 +16,6 @@ import collections from functools import partial import itertools -from typing import Optional, cast from unittest import SkipTest from absl.testing import absltest @@ -32,7 +31,6 @@ from jax import test_util as jtu from jax import lax_reference from jax.test_util import check_grads -from jax.lib import xla_client import jax.util from jax.config import config @@ -1783,1648 +1781,5 @@ def testDeltaConstant(self, dtype, shape, axes): self._Check(make_const, expected) -GradTestSpec = collections.namedtuple( - "GradTestSpec", - ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) -def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): - return GradTestSpec( - op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) - -grad_float_dtypes = list(jtu.supported_dtypes().intersection( - {onp.float32, onp.float64})) -grad_complex_dtypes = list(jtu.supported_dtypes().intersection( - {onp.complex64, onp.complex128})) -grad_inexact_dtypes = grad_float_dtypes + grad_complex_dtypes - -LAX_GRAD_OPS = [ - grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.floor, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), - dtypes=grad_float_dtypes), - grad_test_spec(lax.ceil, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), - dtypes=grad_float_dtypes), - grad_test_spec(lax.round, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4), - dtypes=grad_float_dtypes), - - grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.log1p, nargs=1, order=2, rng_factory=jtu.rand_positive, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.sinh, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes + [onp.complex64], tol=1e-5), - 
grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes, tol=1e-5), - grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes, tol=1e-5), - grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes, tol={onp.float32: 5e-1}), - grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.tan, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), - dtypes=grad_inexact_dtypes, tol=1e-3), - grad_test_spec(lax.asin, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), - dtypes=grad_float_dtypes, tol=1e-3), - grad_test_spec(lax.acos, nargs=1, order=2, - rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3), - dtypes=grad_float_dtypes, tol=2e-2), - # TODO(proteneer): atan2 input is already a representation of a - # complex number. Need to think harder about what this even means - # if each input itself is a complex number. - grad_test_spec(lax.atan2, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - - grad_test_spec(lax.erf, nargs=1, order=2, rng_factory=jtu.rand_small, - dtypes=grad_float_dtypes), - grad_test_spec(lax.erfc, nargs=1, order=2, rng_factory=jtu.rand_small, - dtypes=grad_float_dtypes), - grad_test_spec(lax.erf_inv, nargs=1, order=2, rng_factory=jtu.rand_small, - dtypes=grad_float_dtypes), - # grad_test_spec(lax.lgamma, nargs=1, order=2, rng_factory=jtu.rand_small, - # dtypes=grad_float_dtypes), # TODO(mattjj): enable - grad_test_spec(lax.bessel_i0e, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - grad_test_spec(lax.bessel_i1e, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - - grad_test_spec(lax.real, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_complex_dtypes), - grad_test_spec(lax.imag, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_complex_dtypes), - grad_test_spec(lax.complex, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - grad_test_spec(lax.conj, nargs=1, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.abs, nargs=1, order=2, rng_factory=jtu.rand_positive, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive, - dtypes=grad_inexact_dtypes, tol={onp.float32: 3e-1}), - - grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.sub, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.mul, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_inexact_dtypes), - grad_test_spec(lax.div, nargs=2, order=1, rng_factory=jtu.rand_not_small, - dtypes=grad_inexact_dtypes), - - grad_test_spec(lax.max, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - grad_test_spec(lax.min, nargs=2, order=2, rng_factory=jtu.rand_default, - dtypes=grad_float_dtypes), - # TODO(mattjj): make some-equal checks more robust, enable second-order - # grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal, - # dtypes=grad_float_dtypes, name="MaxSomeEqual"), - # grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal, - # dtypes=grad_float_dtypes, name="MinSomeEqual"), -] - -GradSpecialValuesTestSpec = collections.namedtuple( - 
"GradSpecialValuesTestSpec", ["op", "values", "tol"]) -def grad_special_values_test_spec(op, values, tol=None): - return GradSpecialValuesTestSpec(op, values, tol) - -LAX_GRAD_SPECIAL_VALUE_TESTS = [ - grad_special_values_test_spec( - lax.sinh, [0.], - tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None), - grad_special_values_test_spec( - lax.cosh, [0.], - tol={onp.float32: 1e-2} if jtu.device_under_test() == "tpu" else None), - grad_special_values_test_spec(lax.tanh, [0., 1000.]), - grad_special_values_test_spec(lax.sin, [0., onp.pi, onp.pi/2., onp.pi/4.]), - grad_special_values_test_spec(lax.cos, [0., onp.pi, onp.pi/2., onp.pi/4.]), - grad_special_values_test_spec(lax.tan, [0.]), - grad_special_values_test_spec(lax.asin, [0.]), - grad_special_values_test_spec(lax.acos, [0.]), - grad_special_values_test_spec(lax.atan, [0., 1000.]), - grad_special_values_test_spec(lax.erf, [0., 10.]), - grad_special_values_test_spec(lax.erfc, [0., 10.]), -] - - -def check_grads_bilinear(f, args, order, - modes=["fwd", "rev"], atol=None, rtol=None): - # Can use large eps to make up for numerical inaccuracies since the op is - # bilinear (relying on the fact that we only check one arg at a time) - lhs, rhs = args - check_grads(lambda lhs: f(lhs, rhs), (lhs,), order, - modes=modes, atol=atol, rtol=rtol, eps=1.) - check_grads(lambda rhs: f(lhs, rhs), (rhs,), order, - modes=modes, atol=atol, rtol=rtol, eps=1.) - -class LaxAutodiffTest(jtu.JaxTestCase): - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.name, shapes, itertools.repeat(dtype)), - "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, - "order": rec.order, "tol": rec.tol} - for shape_group in compatible_shapes - for shapes in CombosWithReplacement(shape_group, rec.nargs) - for dtype in rec.dtypes) - for rec in LAX_GRAD_OPS)) - def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): - rng = rng_factory(self.rng()) - if jtu.device_under_test() == "tpu" and op is lax.pow: - raise SkipTest("pow grad imprecise on tpu") - tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol - args = tuple(rng(shape, dtype) for shape in shapes) - check_grads(op, args, order, ["fwd", "rev"], tol, tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), - "op": rec.op, "special_value": special_value, "tol": rec.tol} - for special_value in rec.values) - for rec in LAX_GRAD_SPECIAL_VALUE_TESTS)) - def testOpGradSpecialValue(self, op, special_value, tol): - check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_from_dtype={}_to_dtype={}".format( - jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)), - "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory} - for from_dtype, to_dtype in itertools.product( - float_dtypes + complex_dtypes, repeat=2) - for rng_factory in [jtu.rand_default])) - def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): - rng = rng_factory(self.rng()) - tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), - jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) - args = (rng((2, 3), from_dtype),) - convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) - convert_element_type = 
jtu.ignore_warning(category=onp.ComplexWarning)( - convert_element_type) - check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype, - "rng_factory": rng_factory} - for shape in [(), (2, 3)] - for dtype in grad_float_dtypes - for rng_factory in [jtu.rand_default])) - def testClampGrad(self, shape, dtype, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - low = operand - dtype(10) - high = operand + dtype(10) - # Avoids points near the boundary where the gradient may be inaccurate. - check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2) - check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2) - check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( - dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name, - num_arrs), - "dim": dim, "base_shape": base_shape, "dtype": dtype, - "num_arrs": num_arrs, "rng_factory": rng_factory} - for num_arrs in [3] - for dtype in float_dtypes - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for dim in range(len(base_shape)) - for rng_factory in [jtu.rand_default])) - def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory): - rng = rng_factory(self.rng()) - shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] - operands = tuple(rng(shape, dtype) for shape in shapes) - concatenate = lambda *args: lax.concatenate(args, dim) - check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) 
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding, "rng_factory": rng_factory,} - for lhs_shape, rhs_shape, all_strides in itertools.chain( - [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) - for b, i, j in itertools.product([2, 3], repeat=3)], - [((4, 2, 1), (3, 2, 1), [(1,)])]) - for strides in all_strides - for dtype in float_dtypes - for padding in ["VALID", "SAME"] - for rng_factory in [jtu.rand_small])) - def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory): - rng = rng_factory(self.rng()) - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - conv = partial(lax.conv, window_strides=strides, padding=padding, - precision=lax.Precision.HIGHEST) - check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], - atol=1e-2, rtol=1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" - "rhs_dilation={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding, lhs_dil, rhs_dil), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding, "lhs_dil": lhs_dil, - "rhs_dil": rhs_dil, "rng_factory": rng_factory} - for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in - itertools.chain( - [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], - [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], - [(1, 1), (2, 1)], [(1, 1)]) - for b, i, j in itertools.product([2, 3], repeat=3)], - [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)], - [(1,), (2,)], [(1,), (2,)])]) - for strides in all_strides - for rhs_dil in rhs_dils - for lhs_dil in lhs_dils - for dtype in float_dtypes - for padding in all_pads - for rng_factory in [jtu.rand_small])) - def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, - padding, lhs_dil, rhs_dil, rng_factory): - rng = rng_factory(self.rng()) - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - conv = partial(lax.conv_with_general_padding, window_strides=strides, - padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, - precision=lax.Precision.HIGHEST) - check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], - atol=1e-2, rtol=1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" - "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), - feature_group_count, batch_group_count), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "strides": strides, "padding": padding, "lhs_dil": lhs_dil, - "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, - "perms": perms, "feature_group_count": feature_group_count, - "batch_group_count": batch_group_count} - for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) - for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ - ([(b * 
batch_group_count, i * feature_group_count, 6, 7), - (b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape - (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape - [(1, 1), (1, 2), (2, 1)], # strides - [(1, 1), (2, 1)], # lhs_dils - [(1, 1), (2, 2)]) # rhs_dils - for b, i, j in itertools.product([1, 2], repeat=3)] - for lhs_shape in lhs_shapes - for strides in all_strides - for rhs_dil in rhs_dils - for lhs_dil in lhs_dils - for dtype in grad_float_dtypes - for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + - ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) - for dim_nums, perms in [ - (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), - (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), - (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] - for rng_factory in [jtu.rand_default] - )) - def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, - padding, lhs_dil, rhs_dil, dimension_numbers, - perms, feature_group_count, batch_group_count, - rng_factory): - if dtype == onp.float16: - raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve - - rng = rng_factory(self.rng()) - tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4} - - # permute shapes to match dim_spec, scale by feature_group_count - lhs_perm, rhs_perm = perms - lhs_shape = list(onp.take(lhs_shape, lhs_perm)) - rhs_shape = list(onp.take(rhs_shape, rhs_perm)) - - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - conv = partial(lax.conv_general_dilated, window_strides=strides, - padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, - batch_group_count=batch_group_count, - precision=lax.Precision.HIGHEST) - check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], - atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( - jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype)), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "rng_factory": jtu.rand_default} - for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] - for dtype in float_dtypes)) - def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): - rng = rng_factory(self.rng()) - tol = {onp.float16: 1e-1, onp.float32: 1e-4} - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - dot = partial(lax.dot, precision=lax.Precision.HIGHEST) - check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"], - atol=tol, rtol=tol) - # check that precision config is preserved - result, pullback = api.vjp(dot, lhs, rhs) - gresult = lax.zeros_like_array(result) - s = str(api.make_jaxpr(pullback)(gresult)) - assert "precision=HIGHEST" in s - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" - .format(jtu.format_shape_dtype_string(lhs_shape, dtype), - jtu.format_shape_dtype_string(rhs_shape, dtype), - dimension_numbers), - "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small} - for lhs_shape, rhs_shape, dimension_numbers in [ - ((3, 2), (2, 4), (([1], [0]), ([], []))), - ((3, 5), (2, 5), (([1], [1]), ([], []))), - ((5, 3), (5, 2), (([0], [0]), ([], []))), - ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), - ] - for dtype in float_dtypes)) - def 
testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, - dimension_numbers, rng_factory): - rng = rng_factory(self.rng()) - lhs = rng(lhs_shape, dtype) - rhs = rng(rhs_shape, dtype) - dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, - precision=lax.Precision.HIGHEST) - check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) - # check that precision config is preserved - result, pullback = api.vjp(dot_general, lhs, rhs) - gresult = lax.zeros_like_array(result) - s = str(api.make_jaxpr(pullback)(gresult)) - assert "precision=HIGHEST" in s - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( - shape, onp.dtype(dtype).name, broadcast_sizes), - "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, - "rng_factory": rng_factory} - for shape in [(), (2, 3)] - for dtype in float_dtypes - for broadcast_sizes in [(), (2,), (1, 2)] - for rng_factory in [jtu.rand_default])) - def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): - rng = rng_factory(self.rng()) - args = (rng(shape, dtype),) - broadcast = lambda x: lax.broadcast(x, broadcast_sizes) - check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( - jtu.format_shape_dtype_string(inshape, dtype), - outshape, broadcast_dimensions), - "inshape": inshape, "dtype": dtype, "outshape": outshape, - "dimensions": broadcast_dimensions, "rng_factory": rng_factory} - for inshape, outshape, broadcast_dimensions in [ - ([2], [2, 2], [0]), - ([2], [2, 2], [1]), - ([2], [2, 3], [0]), - ([], [2, 3], []), - ] - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(inshape, dtype) - broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) - check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_perm={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype), - permutation), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "rng_factory": rng_factory, "permutation": permutation} - for dtype in float_dtypes - for arg_shape, out_shape, permutation in [ - [(3, 4), (12,), None], - [(2, 1, 4), (8,), None], - [(2, 2, 4), (2, 8), None], - [(3, 4), (12,), (0, 1)], - [(3, 4), (12,), (1, 0)], - [(2, 1, 4), (8,), (0, 2, 1)], - [(2, 1, 4), (8,), (2, 0, 1)], - [(2, 2, 4), (2, 8), (0, 2, 1)], - [(2, 2, 4), (2, 8), (2, 0, 1)], - ] - for rng_factory in [jtu.rand_default])) - def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(arg_shape, dtype) - reshape = lambda x: lax.reshape(x, out_shape, permutation) - check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) 
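Written out by hand, one testReshapeGrad case with a dimension permutation looks roughly like this sketch (the input values are arbitrary; the shapes mirror the [(3, 4), (12,), (1, 0)] parameterization above):

import numpy as onp
from jax import lax
from jax.test_util import check_grads

# Hypothetical standalone analogue: transpose a (3, 4) array via the
# permutation (1, 0) and flatten it to (12,) in a single lax.reshape call.
x = onp.arange(12., dtype=onp.float32).reshape(3, 4)
reshape = lambda x: lax.reshape(x, (12,), (1, 0))
check_grads(reshape, (x,), order=2, modes=["fwd", "rev"], eps=1.)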
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_pads={}" - .format(jtu.format_shape_dtype_string(shape, dtype), pads), - "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small} - for shape in [(2, 3)] - for dtype in float_dtypes - for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) - def testPadGrad(self, shape, dtype, pads, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) - check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) - - operand = rng(shape, dtype) - padding_value = onp.array(0., dtype) - pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) - check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) - - def testReverseGrad(self): - rev = lambda operand: lax.rev(operand, dimensions) - - dimensions = [0] - check_grads(rev, (onp.array([3., 2., 1.]),), 2) - - dimensions = [0, 1] - check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2, - rtol={onp.float32: 3e-3}) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_predshape={}_argshapes={}".format( - jtu.format_shape_dtype_string(pred_shape, onp.bool_), - jtu.format_shape_dtype_string(arg_shape, dtype)), - "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype, - "rng_factory": rng_factory} - for arg_shape in [(), (3,), (2, 3)] - for pred_shape in ([(), arg_shape] if arg_shape else [()]) - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): - rng = rng_factory(self.rng()) - pred = rng(pred_shape, onp.bool_) - on_true = rng(arg_shape, dtype) - on_false = rng(arg_shape, dtype) - select = lambda on_true, on_false: lax.select(pred, on_true, on_false) - check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - "_shape={}_start_indices={}_limit_indices={}_strides={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, limit_indices, strides), - "shape": shape, "dtype": dtype, "starts": start_indices, - "limits": limit_indices, "strides": strides, "rng_factory": rng_factory} - for shape, start_indices, limit_indices, strides in [ - [(3,), (1,), (2,), None], - [(7,), (4,), (7,), None], - [(5,), (1,), (5,), (2,)], - [(8,), (1,), (6,), (2,)], - [(5, 3), (1, 1), (3, 2), None], - [(5, 3), (1, 1), (3, 1), None], - [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], - [(5, 3), (1, 1), (2, 1), (1, 1)], - [(5, 3), (1, 1), (5, 3), (2, 1)], - ] - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - slice = lambda x: lax.slice(x, starts, limits, strides) - check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) 
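One of the strided testSliceGrad cases above, spelled out standalone (input values chosen arbitrarily), is roughly:

import numpy as onp
from jax import lax
from jax.test_util import check_grads

# Hypothetical standalone analogue of the [(5,), (1,), (5,), (2,)] case:
# take every second element of a length-5 vector, starting at index 1.
x = onp.arange(5., dtype=onp.float32)
slice_fn = lambda x: lax.slice(x, (1,), (5,), (2,))
check_grads(slice_fn, (x,), order=2, modes=["fwd", "rev"], eps=1.)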
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, size_indices), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "size_indices": size_indices, "rng_factory": rng_factory} - for shape, start_indices, size_indices in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - ] - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, - rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) - check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype), - start_indices, update_shape), - "shape": shape, "dtype": dtype, "start_indices": start_indices, - "update_shape": update_shape, "rng_factory": rng_factory} - for shape, start_indices, update_shape in [ - [(3,), (1,), (1,)], - [(5, 3), (1, 1), (3, 1)], - [(7, 5, 3), (4, 1, 0), (2, 0, 1)], - ] - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, - update_shape, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - update = rng(update_shape, dtype) - start_indices = onp.array(start_indices) - - dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices) - check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.) - - dus = lambda x: lax.dynamic_update_slice(x, update, start_indices) - check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.) - - dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices) - check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_perm={}".format( - jtu.format_shape_dtype_string(shape, dtype), perm), - "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory} - for shape, perm in [ - [(3, 4), (1, 0)], - [(3, 4), (0, 1)], - [(3, 4, 5), (2, 1, 0)], - [(3, 4, 5), (1, 0, 2)], - ] - for dtype in float_dtypes - for rng_factory in [jtu.rand_default])) - def testTransposeGrad(self, shape, dtype, perm, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - transpose = lambda x: lax.transpose(x, perm) - check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) 
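A hand-written version of one testDynamicUpdateSliceGrad case above, checking gradients with respect to both the operand and the update, might look like the following sketch (operand and update values are arbitrary; passing the start indices as an onp array mirrors the test itself):

import numpy as onp
from jax import lax
from jax.test_util import check_grads

# Hypothetical standalone analogue: write a (3, 1) update into a (5, 3)
# operand at offset (1, 1); gradients flow into both arguments.
operand = onp.zeros((5, 3), onp.float32)
update = onp.ones((3, 1), onp.float32)
start_indices = onp.array([1, 1])
dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
check_grads(dus, (operand, update), order=2, modes=["fwd", "rev"], eps=1.)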
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_inshape={}_reducedims={}" - .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims), - "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, - "dims": dims, "rng_factory": rng_factory} - for init_val, op, dtypes, rng_factory in [ - (0, lax.add, inexact_dtypes, jtu.rand_default), - (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int), - (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int), - (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)), - ] - for dtype in dtypes - for shape, dims in [ - [(3, 4, 5), ()], - [(3, 4, 5), (0,)], - [(3, 4, 5), (1, 2)], - [(3, 4, 5), (0, 2)], - [(3, 4, 5), (0, 1, 2)], - [(3, 1), (1,)], - ])) - def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): - rng = rng_factory(self.rng()) - if jtu.device_under_test() == "tpu" and op is lax.mul: - raise SkipTest("unimplemented case") - tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1, - onp.float64: 1e-3, onp.complex64: 1e-1} - operand = rng(shape, dtype) - init_val = onp.asarray(init_val, dtype=dtype) - reduce = lambda operand: lax.reduce(operand, init_val, op, dims) - eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else - 1e-1 if dtype == dtypes.bfloat16 else - 1e-2 if dtypes.finfo(dtype).bits == 32 else None) - check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}_padding={}" - .format(op.__name__, onp.dtype(dtype).name, padding), - "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, - "rng_factory": rng_factory} - for init_val, op, dtypes, rng_factory in [ - (0, lax.add, grad_float_dtypes, jtu.rand_small), - (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), - (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), - ] - for dtype in dtypes - for padding in ["VALID", "SAME"])) - @jtu.skip_on_flag("jax_skip_slow_tests", True) - @jtu.ignore_warning(category=UserWarning, - message="Using reduced precision for gradient.*") - def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): - rng = rng_factory(self.rng()) - init_val = onp.asarray(init_val, dtype=dtype) - - # We need this conditional and the corresponding loop logic to be in the - # test method, rather than at the parameterized test level, because it - # depends on FLAGS for the device under test. - # TODO(b/31565929): enable when fixed. - if jtu.device_under_test() == "tpu" and op is not lax.add: - all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))] - - # TODO(b/73062247): need variadic reduce-window for better precision. - gradient_order = 1 - else: - all_configs = itertools.chain( - itertools.product( - [(4, 6)], # shapes - [(2, 1), (1, 2)], # window_dimensions - [(1, 1), (2, 1), (1, 2)] # strides - ), - itertools.product( - [(3, 2, 4, 6)], # shapes - [(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions - [(1, 2, 2, 1), (1, 1, 1, 1)]), # strides - ) - gradient_order = 3 - - def fun(operand): - return lax.reduce_window(operand, init_val, op, dims, strides, padding) - - for shape, dims, strides in all_configs: - operand = rng(shape, dtype) - if op is lax.add: - eps = 1. 
- tol = None - else: - # this test can fail if there are duplicates in operand - self.assertEqual(onp.unique(operand).size, operand.size, - msg="test requires operand elements to be unique.") - eps = 1e-2 - tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2} - check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, - eps) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_shape={}_axis={}" - .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis), - "op": op, "shape": shape, "dtype": dtype, - "axis": axis, "rng_factory": rng_factory} - for op, types in [ - (lax.cumsum, [onp.float32, onp.float64]), - (lax.cumprod, [onp.float32, onp.float64]), - ] - for dtype in types - for shape in [[10], [3, 4, 5]] - for axis in range(len(shape)) - for rng_factory in [ - jtu.rand_default if dtypes.issubdtype(dtype, onp.integer) - else jtu.rand_small])) - def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory): - rng = rng_factory(self.rng()) - check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2) - - - # TODO(b/205052657): enable more tests when supported - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis} - for dtype in [onp.float32] - for shape in [(5,), (5, 7)] - for axis in [len(shape) - 1] - for rng_factory in [jtu.rand_default])) - def testSortGrad(self, shape, dtype, axis, rng_factory): - rng = rng_factory(self.rng()) - operand = rng(shape, dtype) - sort = lambda x: lax.sort(x, dimension=axis) - check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) - - # TODO(b/205052657): enable more tests when supported - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_keyshape={}_valshape={}_axis={}".format( - jtu.format_shape_dtype_string(shape, key_dtype), - jtu.format_shape_dtype_string(shape, val_dtype), - axis), - "rng_factory": rng_factory, "shape": shape, - "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis} - for key_dtype in [onp.float32] - for val_dtype in [onp.float32] - for shape in [(3,), (5, 3)] - for axis in [len(shape) - 1] - for rng_factory in [jtu.rand_default])) - def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, rng_factory): - rng = rng_factory(self.rng()) - # This test relies on the property that wherever keys are tied, values are - # too, since we don't guarantee the same ordering of values with equal keys. - # To avoid that case, we generate unique keys (globally in the key array). 
- def args_maker(): - flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype) - keys = self.rng().permutation(flat_keys).reshape(shape) - values = rng(shape, val_dtype) - return keys, values - keys, values = args_maker() - - fun = lambda keys, values: lax.sort_key_val(keys, values, axis) - check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k} - for dtype in [onp.float32,] - for shape in [(4,), (5, 5), (2, 1, 4)] - for k in [1, 3] - for rng_factory in [jtu.rand_default])) - def testTopKGrad(self, shape, dtype, k, rng_factory): - flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype) - values = self.rng().permutation(flat_values).reshape(shape) - fun = lambda vs: lax.top_k(vs, k=k)[0] - check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), idxs, axes), - "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, - "rng_factory": rng_factory} - for dtype in float_dtypes - for shape, idxs, axes in [ - [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)], - [(3, 4, 5), (onp.array([-1, -2]),), (0,)], - [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)], - [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)], - ] - for rng_factory in [jtu.rand_default])) - def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory): - rng = rng_factory(self.rng()) - src = rng(shape, dtype) - index_take = lambda src: lax.index_take(src, idxs, axes) - check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( - jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, - slice_sizes), - "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, - "slice_sizes": slice_sizes, "rng_factory": rng_factory, - "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes - for shape, idxs, dnums, slice_sizes, max_idx in [ - ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers( - offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), - (1,), 5), - ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers( - offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), - (2,), 9), - ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers( - offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), - (1, 3), 3), - ] - for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] - for rng_factory in [jtu.rand_default])) - def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory, - rng_idx_factory): - rng = rng_factory(self.rng()) - rng_idx = rng_idx_factory(self.rng()) - idxs = rng_idx(idxs.shape, idxs.dtype) - gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums, - slice_sizes=slice_sizes) - x = rng(shape, dtype) - check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
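The first testGatherGrad parameterization corresponds to roughly this standalone sketch (the operand values are arbitrary; the dimension numbers and slice sizes mirror the case above):

import numpy as onp
from jax import lax
from jax.test_util import check_grads

# Hypothetical standalone analogue: pick the elements at positions 0 and 2
# out of a length-5 vector and check first- and second-order gradients.
dnums = lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
idxs = onp.array([[0], [2]])
x = onp.arange(5., dtype=onp.float32)
gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                              slice_sizes=(1,))
check_grads(gather, (x,), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2, eps=1.)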
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, - "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes - for arg_shape, idxs, update_shape, dnums, max_idx in [ - ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,)), 4), - ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0,)), 9), - ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,)), 3), - ] - for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] - for rng_factory in [jtu.rand_default])) - def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums, - rng_factory, rng_idx_factory): - rng = rng_factory(self.rng()) - rng_idx = rng_idx_factory(self.rng()) - idxs = rng_idx(idxs.shape, idxs.dtype) - scatter_add = lambda x, y: lax.scatter_add(x, idxs, y, - dimension_numbers=dnums) - x = rng(arg_shape, dtype) - y = rng(update_shape, dtype) - check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, - "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes - for arg_shape, idxs, update_shape, dnums, max_idx in [ - ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,)), 4), - ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0,)), 9), - ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,)), 3), - ] - # Scatters with conflicting indices are not deterministic on GPU, so we - # use indices that do not collide. - for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)] - for rng_factory in [jtu.rand_default])) - def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, - rng_factory, rng_idx_factory): - rng = rng_factory(self.rng()) - rng_idx = rng_idx_factory(self.rng()) - idxs = rng_idx(idxs.shape, idxs.dtype) - scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums) - x = rng(arg_shape, dtype) - y = rng(update_shape, dtype) - check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) - - def testScatterGradSymbolicZeroUpdate(self): - # https://github.com/google/jax/issues/1901 - def f(x): - n = x.shape[0] - y = onp.arange(n, dtype=x.dtype) - return jax.ops.index_update(x, onp.diag_indices(n), y) - rng = jtu.rand_default(self.rng()) - check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, - 1.) 
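For comparison with the parameterized scatter tests above, the first testScatterAddGrad case unrolls to approximately this sketch (operand and update values chosen arbitrarily):

import numpy as onp
from jax import lax
from jax.test_util import check_grads

# Hypothetical standalone analogue: add scalar updates into positions 0 and 2
# of a length-5 operand and gradient-check with respect to both inputs.
dnums = lax.ScatterDimensionNumbers(update_window_dims=(),
                                    inserted_window_dims=(0,),
                                    scatter_dims_to_operand_dims=(0,))
idxs = onp.array([[0], [2]])
x = onp.zeros((5,), onp.float32)
y = onp.ones((2,), onp.float32)
scatter_add = lambda x, y: lax.scatter_add(x, idxs, y, dimension_numbers=dnums)
check_grads(scatter_add, (x, y), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2, eps=1.)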
- - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums, - "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes - for arg_shape, idxs, update_shape, dnums in [ - ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,))), - ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0,))), - ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,))), - ] - for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] - for rng_factory in [jtu.rand_default])) - def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums, - rng_factory, rng_idx_factory): - rng = rng_factory(self.rng()) - rng_idx = rng_idx_factory(self.rng()) - idxs = rng_idx(idxs.shape, idxs.dtype) - scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums) - x = rng(arg_shape, dtype) - y = rng(update_shape, dtype) - check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - idxs, update_shape, dnums), - "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, - "update_shape": update_shape, "dnums": dnums, - "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} - for dtype in grad_float_dtypes - for arg_shape, idxs, update_shape, dnums in [ - ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,))), - ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0,))), - ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,))), - ] - for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] - for rng_factory in [jtu.rand_default])) - def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums, - rng_factory, rng_idx_factory): - rng = rng_factory(self.rng()) - rng_idx = rng_idx_factory(self.rng()) - idxs = rng_idx(idxs.shape, idxs.dtype) - scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums) - x = rng(arg_shape, dtype) - y = rng(update_shape, dtype) - check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) - - def testStopGradient(self): - def f(x): - return lax.sin(x) * lax.cos(lax.stop_gradient(x)) - - def f2(x, y): - return lax.sin(x) * lax.cos(y) - - x = 3.14 - ans = api.grad(f)(x) - expected = api.grad(f2)(x, x) - self.assertAllClose(ans, expected) - - ans = api.grad(api.grad(f))(x) - expected = api.grad(api.grad(f2))(x, x) - self.assertAllClose(ans, expected) - - ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.) 
- expected = onp.array(0.0) - self.assertAllClose(ans, expected, check_dtypes=False) - - with core.skipping_checks(): - with self.assertRaises(TypeError): - lax.stop_gradient(lambda x: x) - - # TODO(mattjj): make this a more systematic test - def testRemainder(self): - rng = onp.random.RandomState(0) - x = rng.uniform(-0.9, 9, size=(3, 4)) - y = rng.uniform(0.7, 1.9, size=(3, 1)) - assert not set(onp.unique(x)) & set(onp.unique(y)) - tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 - check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) - - rng = onp.random.RandomState(0) - x = rng.uniform(-0.9, 9, size=(1, 4)) - y = rng.uniform(0.7, 1.9, size=(3, 4)) - assert not set(onp.unique(x)) & set(onp.unique(y)) - tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 - check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) - - def testHigherOrderGradientOfReciprocal(self): - # Regression test for https://github.com/google/jax/issues/3136 - def inv(x): - # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) - return 1 / x - grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv)))))) - self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
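The constant asserted in testHigherOrderGradientOfReciprocal is just the sixth derivative of 1/x at x = 4, i.e. 6!/4**7 = 720/16384 = 0.0439453125 exactly; a quick standalone check (not part of the patch):

import numpy as onp
import jax

# Not part of the patch: the sixth derivative of 1/x is 720 / x**7, so the
# expected value at x = 4 is 720 / 4**7 = 45 / 1024 = 0.0439453125 exactly.
assert 720 / 4**7 == 0.0439453125

inv = lambda x: 1 / x
sixth_derivative = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
print(sixth_derivative(onp.float32(4.)))  # approximately 0.0439453125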
- - if __name__ == '__main__': absltest.main() diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py new file mode 100644 index 000000000000..4135f2520e9c --- /dev/null +++ b/tests/lax_vmap_test.py @@ -0,0 +1,686 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ + +from functools import partial +import itertools +from typing import Optional, cast +from unittest import SkipTest + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as onp + +import jax +from jax import api +from jax import dtypes +from jax import lax +from jax import test_util as jtu +from jax.lib import xla_client + +from tests.lax_test import (all_dtypes, CombosWithReplacement, + compatible_shapes, default_dtypes, float_dtypes, + int_dtypes, LAX_OPS) + +from jax.config import config +config.parse_flags_with_absl() +FLAGS = config.FLAGS + +def all_bdims(*shapes): + bdims = (itertools.chain([cast(Optional[int], None)], + range(len(shape) + 1)) for shape in shapes) + return (t for t in itertools.product(*bdims) if not all(e is None for e in t)) + +def add_bdim(bdim_size, bdim, shape): + shape = list(shape) + if bdim is not None: + shape.insert(bdim, bdim_size) + return tuple(shape) + +def slicer(x, bdim): + if bdim is None: + return lambda _: x + else: + return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False) + +def args_slicer(args, bdims): + slicers = list(map(slicer, args, bdims)) + return lambda i: [sl(i) for sl in slicers] + +class LaxVmapTest(jtu.JaxTestCase): + + def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, + rtol=None, atol=None): + batched_shapes = list(jax.util.safe_map(partial(add_bdim, bdim_size), + bdims, shapes)) + args = [rng(shape, dtype) + for shape, dtype in jax.util.safe_zip(batched_shapes, dtypes)] + args_slice = args_slicer(args, bdims) + ans = api.vmap(op, bdims)(*args) + expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)]) + self.assertAllClose(ans, expected, rtol=rtol, atol=atol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_bdims={}".format( + jtu.format_test_name_suffix(rec.op, shapes, + itertools.repeat(dtype)), bdims), + "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, + "dtype": dtype, "bdims": bdims, "tol": rec.tol} + for shape_group in compatible_shapes + for shapes in CombosWithReplacement(shape_group, rec.nargs) + for bdims in all_bdims(*shapes) + for dtype in rec.dtypes) + for rec in LAX_OPS)) + def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): + rng = rng_factory(self.rng()) + op = getattr(lax, op_name) + self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, + atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" + "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" + "_lhs_bdim={}_rhs_bdim={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), + feature_group_count, batch_group_count, lhs_bdim, rhs_bdim), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "strides": strides, "padding": padding, "lhs_dil": lhs_dil, + "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, + "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, + "feature_group_count": feature_group_count, + "batch_group_count": batch_group_count, + } + for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) + for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [ + ((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape + (j * 
batch_group_count * feature_group_count, i, 1, 2), # rhs_shape + [(1, 1), (1, 2), (2, 1)], # strides + [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads + [(1, 1), (2, 1)], # lhs_dils + [(1, 1), (2, 2)]) # rhs_dils + for b, i, j in itertools.product([1, 2], repeat=3)] + for strides in all_strides + for rhs_dil in rhs_dils + for lhs_dil in lhs_dils + for dtype in [onp.float32] + for padding in all_pads + for dim_nums, perms in [ + (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), + (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), + (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] + for lhs_bdim in itertools.chain([cast(Optional[int], None)], + range(len(lhs_shape) + 1)) + for rhs_bdim in itertools.chain([cast(Optional[int], None)], + range(len(rhs_shape) + 1)) + if (lhs_bdim, rhs_bdim) != (None, None) + for rng_factory in [jtu.rand_default] + )) + def testConvGeneralDilatedBatching( + self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, + dimension_numbers, perms, feature_group_count, batch_group_count, + lhs_bdim, rhs_bdim, rng_factory): + rng = rng_factory(self.rng()) + tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 + + # permute shapes to match dim_spec, scale by feature_group_count + lhs_perm, rhs_perm = perms + lhs_shape = list(onp.take(lhs_shape, lhs_perm)) + rhs_shape = list(onp.take(rhs_shape, rhs_perm)) + + conv = partial(lax.conv_general_dilated, window_strides=strides, + padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count, + precision=lax.Precision.HIGHEST) + self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), + (dtype, dtype), rng, rtol=tol, atol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( + shape, from_dtype, to_dtype, bdims), + "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, + "bdims": bdims, "rng_factory": rng_factory} + for from_dtype, to_dtype in itertools.product( + [onp.float32, onp.int32, "float32", "int32"], repeat=2) + for shape in [(2, 3)] + for bdims in all_bdims(shape) + for rng_factory in [jtu.rand_default])) + def testConvertElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.convert_element_type(x, to_dtype) + self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( + shape, from_dtype, to_dtype, bdims), + "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, + "bdims": bdims, "rng_factory": rng_factory} + for from_dtype, to_dtype in itertools.product( + [onp.float32, onp.int32, "float32", "int32"], repeat=2) + for shape in [(2, 3)] + for bdims in all_bdims(shape) + for rng_factory in [jtu.rand_default])) + def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.bitcast_convert_type(x, to_dtype) + self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}" + .format(jtu.format_shape_dtype_string(min_shape, dtype), + jtu.format_shape_dtype_string(operand_shape, dtype), + jtu.format_shape_dtype_string(max_shape, 
dtype), + bdims), + "min_shape": min_shape, "operand_shape": operand_shape, + "max_shape": max_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory} + for min_shape, operand_shape, max_shape in [ + [(), (2, 3), ()], + [(2, 3), (2, 3), ()], + [(), (2, 3), (2, 3)], + [(2, 3), (2, 3), (2, 3)], + ] + for dtype in default_dtypes + for bdims in all_bdims(min_shape, operand_shape, max_shape) + for rng_factory in [jtu.rand_default])) + def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims, rng_factory): + rng = rng_factory(self.rng()) + raise SkipTest("batching rule for clamp not implemented") # TODO(mattj) + shapes = [min_shape, operand_shape, max_shape] + self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( + jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + bdims), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "bdims": bdims, "rng_factory": rng_factory} + for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)] + for bdims in all_bdims(lhs_shape, rhs_shape) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default])) + def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = partial(lax.dot, precision=lax.Precision.HIGHEST) + self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), + rng, rtol={onp.float16: 5e-2, onp.float64: 5e-14}) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + lhs_contracting, rhs_contracting, bdims), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, + "bdims": bdims, "rng_factory": rng_factory} + for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ + [(3, 5), (2, 5), [1], [1]], + [(5, 3), (5, 2), [0], [0]], + [(5, 3, 2), (5, 2, 4), [0], [0]], + [(5, 3, 2), (5, 2, 4), [0,2], [0,1]], + [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], + [(3, 2), (2, 4), [1], [0]], + ] + for bdims in all_bdims(lhs_shape, rhs_shape) + for dtype in default_dtypes + for rng_factory in [jtu.rand_small])) + def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, + lhs_contracting, rhs_contracting, bdims, rng_factory): + rng = rng_factory(self.rng()) + dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) + dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) + self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), + rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}" + .format(jtu.format_shape_dtype_string(lhs_shape, dtype), + jtu.format_shape_dtype_string(rhs_shape, dtype), + dimension_numbers, bdims), + "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, + "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory} + for lhs_shape, rhs_shape, dimension_numbers in [ + ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), + ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), + ] + for bdims in all_bdims(lhs_shape, rhs_shape) + for dtype in default_dtypes + for rng_factory in 
[jtu.rand_small])) + def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, + dimension_numbers, bdims, rng_factory): + rng = rng_factory(self.rng()) + dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) + self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), + rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( + shape, onp.dtype(dtype).name, broadcast_sizes, bdims), + "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, + "bdims": bdims, "rng_factory": rng_factory} + for shape in [(), (2, 3)] + for dtype in default_dtypes + for broadcast_sizes in [(), (2,), (1, 2)] + for bdims in all_bdims(shape) + for rng_factory in [jtu.rand_default])) + def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.broadcast(x, broadcast_sizes) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( + jtu.format_shape_dtype_string(inshape, dtype), + outshape, broadcast_dimensions, bdims), + "inshape": inshape, "dtype": dtype, "outshape": outshape, + "dimensions": broadcast_dimensions, "bdims": bdims, + "rng_factory": rng_factory} + for inshape, outshape, broadcast_dimensions in [ + ([2], [2, 2], [0]), + ([2], [2, 2], [1]), + ([2], [2, 3], [0]), + ([], [2, 3], []), + ] + for dtype in default_dtypes + for bdims in all_bdims(inshape) + for rng_factory in [jtu.rand_default])) + def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng_factory): + rng = rng_factory(self.rng()) + raise SkipTest("this test has failures in some cases") # TODO(mattjj) + op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) + self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format( + jtu.format_shape_dtype_string(arg_shape, onp.float32), + dimensions, bdims), + "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims, + "rng_factory": rng_factory} + for arg_shape, dimensions in [ + [(1,), (0,)], + [(1,), (-1,)], + [(2, 1, 4), (1,)], + [(2, 1, 4), (-2,)], + [(2, 1, 3, 1), (1,)], + [(2, 1, 3, 1), (1, 3)], + [(2, 1, 3, 1), (3,)], + [(2, 1, 3, 1), (1, -1)], + ] + for bdims in all_bdims(arg_shape) + for rng_factory in [jtu.rand_default])) + def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory): + dtype = onp.float32 + rng = rng_factory(self.rng()) + op = lambda x: lax.squeeze(x, dimensions) + self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + dimensions, bdims), + "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, + "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory} + for dtype in default_dtypes + for arg_shape, dimensions, out_shape in [ + [(3, 4), None, (12,)], + [(2, 1, 4), None, (8,)], + [(2, 2, 4), None, (2, 8)], + [(2, 2, 4), (0, 1, 2), (2, 8)], + [(2, 2, 4), (1, 0, 2), (8, 2)], + [(2, 2, 4), (2, 1, 0), (4, 2, 2)] + ] + for bdims in all_bdims(arg_shape) + for rng_factory in [jtu.rand_default])) + def testReshape(self, arg_shape, 
out_shape, dtype, dimensions, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) + self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_pads={}_bdims={}" + .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims), + "shape": shape, "dtype": dtype, "pads": pads, + "rng_factory": jtu.rand_small, "bdims": bdims} + for shape in [(2, 3)] + for bdims in all_bdims(shape) + for dtype in default_dtypes + for pads in [[(1, 2, 1), (0, 1, 0)]])) + def testPad(self, shape, dtype, pads, bdims, rng_factory): + rng = rng_factory(self.rng()) + fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) + self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format( + jtu.format_shape_dtype_string(pred_shape, onp.bool_), + jtu.format_shape_dtype_string(arg_shape, arg_dtype), + bdims), + "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype, + "bdims": bdims, "rng_factory": rng_factory} + for arg_shape in [(), (3,), (2, 3)] + for pred_shape in ([(), arg_shape] if arg_shape else [()]) + for bdims in all_bdims(pred_shape, arg_shape, arg_shape) + for arg_dtype in default_dtypes + for rng_factory in [jtu.rand_default])) + def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda c, x, y: lax.select(c < 0, x, y) + self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,), + (onp.bool_, arg_dtype, arg_dtype), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".format( + jtu.format_shape_dtype_string(shape, dtype), + start_indices, limit_indices, strides, bdims), + "shape": shape, "dtype": dtype, "starts": start_indices, + "limits": limit_indices, "strides": strides, "bdims": bdims, "rng_factory": rng_factory} + for shape, start_indices, limit_indices, strides in [ + [(3,), (1,), (2,), None], + [(7,), (4,), (7,), None], + [(5,), (1,), (5,), (2,)], + [(8,), (1,), (6,), (2,)], + [(5, 3), (1, 1), (3, 2), None], + [(5, 3), (1, 1), (3, 1), None], + [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], + [(5, 3), (1, 1), (2, 1), (1, 1)], + [(5, 3), (1, 1), (5, 3), (2, 1)], + ] + for bdims in all_bdims(shape) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default])) + def testSlice(self, shape, dtype, starts, limits, strides, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.slice(x, starts, limits, strides) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_perm={}_bdims={}".format( + jtu.format_shape_dtype_string(shape, dtype), perm, bdims), + "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims, "rng_factory": rng_factory} + for shape, perm in [ + [(3, 4), (1, 0)], + [(3, 4), (0, 1)], + [(3, 4, 5), (2, 1, 0)], + [(3, 4, 5), (1, 0, 2)], + ] + for bdims in all_bdims(shape) + for dtype in default_dtypes + for rng_factory in [jtu.rand_default])) + def testTranspose(self, shape, dtype, perm, bdims, rng_factory): + rng = rng_factory(self.rng()) + op = lambda x: lax.transpose(x, perm) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + 
{"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}" + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, + init_val, bdims), + "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, + "dims": dims, "bdims": bdims, "rng_factory": rng_factory} + for init_val, op, dtypes in [ + (0, lax.add, default_dtypes), + (1, lax.mul, default_dtypes), + (0, lax.max, all_dtypes), # non-monoidal + (-onp.inf, lax.max, float_dtypes), + (dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]), + (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]), + (dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]), + (dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]), + (onp.inf, lax.min, float_dtypes), + (dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]), + (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]), + (dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]), + (dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]), + ] + for dtype in dtypes + for shape, dims in [ + [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], + [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)] + ] + for bdims in all_bdims(shape) + for rng_factory in [jtu.rand_small])) + def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory): + rng = rng_factory(self.rng()) + init_val = onp.asarray(init_val, dtype=dtype) + fun = lambda operand: lax.reduce(operand, init_val, op, dims) + self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_dtype={}_padding={}" + .format(op.__name__, onp.dtype(dtype).name, padding), + "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, + "rng_factory": rng_factory} + for init_val, op, dtypes in [ + (0, lax.add, [onp.float32]), + (-onp.inf, lax.max, [onp.float32]), + (onp.inf, lax.min, [onp.float32]), + ] + for dtype in dtypes + for padding in ["VALID", "SAME"] + for rng_factory in [jtu.rand_small])) + def testReduceWindow(self, op, init_val, dtype, padding, rng_factory): + rng = rng_factory(self.rng()) + init_val = onp.asarray(init_val, dtype=dtype) + + all_configs = itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)]), + itertools.product( + [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)])) + + def fun(operand): + return lax.reduce_window(operand, init_val, op, dims, strides, padding) + + for shape, dims, strides in all_configs: + for bdims in all_bdims(shape): + self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_shape={}_axis={}_bdims={}" + .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, + bdims), + "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, + "axis": axis, "rng_factory": rng_factory} + for op, types in [ + (lax.cumsum, [onp.float32, onp.float64]), + (lax.cumprod, [onp.float32, onp.float64]), + ] + for dtype in types + for shape in [[10], [3, 4, 5]] + for axis in range(len(shape)) + for bdims in all_bdims(shape) + for rng_factory in [ + jtu.rand_default if dtypes.issubdtype(dtype, onp.integer) + else jtu.rand_small])) + def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory): + rng = rng_factory(self.rng()) + self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape,), (dtype,), + rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_dtype={}_padding={}".format(onp.dtype(dtype).name, + padding), + "dtype": dtype, 
"padding": padding, "rng_factory": rng_factory} + for dtype in float_dtypes + for padding in ["VALID", "SAME"] + for rng_factory in [jtu.rand_small])) + @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.ignore_warning(message="Using reduced precision for gradient.*") + def testSelectAndGatherAdd(self, dtype, padding, rng_factory): + if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16: + raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu") + rng = rng_factory(self.rng()) + all_configs = itertools.chain( + itertools.product( + [(4, 6)], + [(2, 1), (1, 2)], + [(1, 1), (2, 1), (1, 2)]), + itertools.product( + [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], + [(1, 2, 2, 1), (1, 1, 1, 1)])) + + def fun(operand, tangents): + return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, + strides, padding) + + for shape, dims, strides in all_configs: + for bdims in all_bdims(shape, shape): + self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_bdims={}_fft_ndims={}" + .format(shape, bdims, fft_ndims), + "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims, "rng_factory": rng_factory} + for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)] + for bdims in all_bdims(shape) + for fft_ndims in range(0, min(3, len(shape)) + 1) + for rng_factory in [jtu.rand_default])) + @jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases. + def testFft(self, fft_ndims, shape, bdims, rng_factory): + rng = rng_factory(self.rng()) + ndims = len(shape) + axes = range(ndims - fft_ndims, ndims) + fft_lengths = [shape[axis] for axis in axes] + op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) + self._CheckBatching(op, 5, bdims, [shape], [onp.complex64], rng) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}" + .format(jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, + slice_sizes, bdims), + "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, + "slice_sizes": slice_sizes, "bdims": bdims} + for dtype in all_dtypes + for shape, idxs, dnums, slice_sizes in [ + ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), + (1,)), + ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), + (2,)), + ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), + (1, 3)), + ((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( + offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), + (1, 3)), + ] + for bdims in all_bdims(shape, idxs.shape))) + def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): + fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) + self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], + jtu.rand_default(self.rng())) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + idxs, update_shape, dnums, bdims), + "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, + "update_shape": update_shape, "dnums": dnums, "bdims": bdims} + for dtype in float_dtypes + for arg_shape, idxs, update_shape, dnums in [ + ((5,), onp.array([[0], 
[2]]), (2,), lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ] + for bdims in all_bdims(arg_shape, idxs.shape, update_shape))) + def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): + fun = partial(lax.scatter_add, dimension_numbers=dnums) + self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], + [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), + rtol={onp.float16: 5e-3}) + + def testShapeUsesBuiltinInt(self): + x = lax.iota(onp.int32, 3) + 1 + self.assertIsInstance(x.shape[0], int) # not np.int64 + + def testBroadcastShapesReturnsPythonInts(self): + shape1, shape2 = (1, 2, 3), (2, 3) + out_shape = lax.broadcast_shapes(shape1, shape2) + self.assertTrue(all(type(s) is int for s in out_shape)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}_bdims={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, bdims), + "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory} + for shape in [(4,), (3, 5, 3)] + for k in [1, 3] + for bdims in all_bdims(shape) + # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed: + # The top_k indices for integer arrays with identical entries won't match between + # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. + # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of + # values a bfloat16 can represent exactly to avoid ties. + for dtype, rng_factory in itertools.chain( + zip(float_dtypes + int_dtypes, itertools.repeat(jtu.rand_unique_int))))) + def testTopK(self, shape, dtype, k, bdims, rng_factory): + rng = rng_factory(self.rng()) + # _CheckBatching doesn't work with tuple outputs, so test outputs separately. + op1 = lambda x: lax.top_k(x, k=k)[0] + self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) + op2 = lambda x: lax.top_k(x, k=k)[1] + self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) + + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_dimension={}_arity={}_bdims={}" + .format(jtu.format_shape_dtype_string(shape, onp.float32), dimension, + arity, bdims), + "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims} + for shape in [(2, 3)] + for dimension in [0, 1] + for arity in range(3) + for bdims in all_bdims(*((shape,) * arity)))) + def testSort(self, shape, dimension, arity, bdims): + rng = jtu.rand_default(self.rng()) + if arity == 1: + fun = partial(lax.sort, dimension=dimension) + self._CheckBatching(fun, 5, bdims, (shape,) * arity, (onp.float32,) * arity, + rng) + else: + for i in range(arity): + fun = lambda *args, i=i: lax.sort(args, dimension=dimension)[i] + self._CheckBatching(fun, 5, bdims, (shape,) * arity, + (onp.float32,) * arity, rng) + + + # TODO Concatenate + # TODO Reverse + # TODO DynamicSlice + # TODO DynamicUpdateSlice + # TODO Collapse + # TODO Scatter + + +if __name__ == '__main__': + absltest.main()
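+
+# For reference, the property _CheckBatching verifies can be written out by
+# hand; the sketch below uses lax.add purely as an illustration and is not an
+# additional test case:
+#
+#   import numpy as onp
+#   import jax
+#   from jax import lax
+#
+#   xs = onp.arange(6, dtype=onp.float32).reshape(3, 2)   # batch of 3 rows
+#   ys = onp.ones((3, 2), onp.float32)
+#   batched = jax.vmap(lax.add, in_axes=(0, 0))(xs, ys)
+#   manual = onp.stack([lax.add(xs[i], ys[i]) for i in range(3)])
+#   assert onp.allclose(batched, manual)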