diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 5fef447851a6..6a1eaf8ca4bf 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -457,9 +457,8 @@ def f(): JAX_COMPOUND_OP_RECORDS))) def testOp(self, onp_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes, tolerance, inexact): - if onp_op is onp.float_power: - onp_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(onp_op) + onp_op = jtu.ignore_warning(category=RuntimeWarning, + message="invalid value.*")(onp_op) rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes, onp_arrays=False) diff --git a/tests/lax_test.py b/tests/lax_test.py index f1880e4d6991..bcfcc36d1361 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -109,7 +109,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None): op_record("rsqrt", 1, float_dtypes + complex_dtypes, jtu.rand_positive), op_record("square", 1, float_dtypes + complex_dtypes, jtu.rand_default), op_record("reciprocal", 1, float_dtypes + complex_dtypes, jtu.rand_positive), - op_record("tan", 1, float_dtypes, jtu.rand_default, {onp.float32: 1e-5}), + op_record("tan", 1, float_dtypes, jtu.rand_default, {onp.float32: 3e-5}), op_record("asin", 1, float_dtypes, jtu.rand_small), op_record("acos", 1, float_dtypes, jtu.rand_small), op_record("atan", 1, float_dtypes, jtu.rand_small), diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 3a5e4cf15475..9cf11302f21f 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -114,7 +114,8 @@ def args_maker(): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, rtol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True, + rtol={onp.float32: 2e-3, onp.float64: 1e-4}) @genNamedParametersNArgs(3, jtu.rand_default) def testCauchyLogPdf(self, rng_factory, shapes, dtypes):