From 07c97046c18e9acfa0a8a15809fc6a9f54f56ff2 Mon Sep 17 00:00:00 2001
From: Christopher Suter
Date: Mon, 8 Jan 2024 09:11:12 -0800
Subject: [PATCH] Fix tf.where-induced nan grads in NormalInverseGaussian

Fixes #1778

PiperOrigin-RevId: 596612661
---
 .../distributions/normal_inverse_gaussian.py | 32 ++++++++++++++++---
 .../normal_inverse_gaussian_test.py          | 12 +++++++
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/tensorflow_probability/python/distributions/normal_inverse_gaussian.py b/tensorflow_probability/python/distributions/normal_inverse_gaussian.py
index 026d4270ff..57e3fb9ba2 100644
--- a/tensorflow_probability/python/distributions/normal_inverse_gaussian.py
+++ b/tensorflow_probability/python/distributions/normal_inverse_gaussian.py
@@ -37,11 +37,35 @@


 def _log1px2(x):
+  """Safely compute log(1 + x ** 2).
+
+  For small x, use log1p(x ** 2). For large x (x >> 1), use 2 * log(x). Also
+  avoid nan grads by using a double-where for x ~= 0.
+
+  Args:
+    x: float `Tensor`.
+
+  Returns:
+    y: log(1 + x ** 2).
+  """
+  # The idea is to use 2 * log(x) when x ** 2 >> 1, so that adding 1 doesn't
+  # matter. This happens when x >> 1 / sqrt(eps). But this causes grad
+  # problems for zero input:
+  #
+  # If x is zero, log(1 + x ** 2) is log(1) = 0. But then 2 * log(x) is
+  # 2 * log(0) = -inf, whose grad is nan. So for zero input we substitute a
+  # safe value into the unselected branches, using the double-where trick
+  # (see, e.g., https://github.com/google/jax/issues/1052).
+  finfo = np.finfo(dtype_util.as_numpy_dtype(x.dtype))
+  is_basically_zero = tf.abs(x) < finfo.tiny
+  safe_x = tf.where(is_basically_zero, tf.ones_like(x), x)
   return tf.where(
-      tf.abs(x) * np.sqrt(np.finfo(
-          dtype_util.as_numpy_dtype(x.dtype)).eps) <= 1.,
-      tf.math.log1p(x**2.),
-      2 * tf.math.log(tf.math.abs(x)))
+      is_basically_zero,
+      tf.abs(x),
+      tf.where(
+          tf.abs(x) * np.sqrt(finfo.eps) <= 1.,
+          tf.math.log1p(safe_x**2.),
+          2 * tf.math.log(tf.math.abs(safe_x))))


 class NormalInverseGaussian(distribution.AutoCompositeTensorDistribution):
diff --git a/tensorflow_probability/python/distributions/normal_inverse_gaussian_test.py b/tensorflow_probability/python/distributions/normal_inverse_gaussian_test.py
index c406ed96c7..9a9176a2bd 100644
--- a/tensorflow_probability/python/distributions/normal_inverse_gaussian_test.py
+++ b/tensorflow_probability/python/distributions/normal_inverse_gaussian_test.py
@@ -17,6 +17,7 @@
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.distributions import normal_inverse_gaussian as nig
 from tensorflow_probability.python.internal import test_util
+from tensorflow_probability.python.math import gradient


 @test_util.test_all_tf_execution_regimes
@@ -218,6 +219,17 @@ def testModifiedVariableAssertion(self):
     with tf.control_dependencies([skewness.assign(-2.)]):
       self.evaluate(normal_inverse_gaussian.mean())

+  @test_util.numpy_disable_gradient_test
+  def testDoubleWhere(self):
+    loc = 0.
+
+    def f(x):
+      return nig.NormalInverseGaussian(
+          loc=x, scale=2., tailweight=1., skewness=2.).log_prob(loc)
+
+    _, g = gradient.value_and_gradient(f, loc)
+    self.assertAllNotNan(g)
+

 class NormalInverseGaussianTestFloat32(
     test_util.TestCase, _NormalInverseGaussianTest):
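
Note (not part of the patch): the sketch below is a minimal standalone
reproduction of the failure mode this change fixes, assuming TF2 eager
execution and float32. `f_single_where` mirrors the old `_log1px2` body and
`f_double_where` the new one; both names are illustrative, not from the TFP
codebase.

import numpy as np
import tensorflow as tf


def f_single_where(x):
  # Old behavior: the value at x == 0 is correct (the log1p branch is
  # selected), but tf.where's gradient also touches the unselected
  # 2 * log(|x|) branch, whose grad at 0 is 0 * inf = nan, and the masked
  # sum 0 + nan is still nan.
  return tf.where(
      tf.abs(x) * np.sqrt(np.finfo(np.float32).eps) <= 1.,
      tf.math.log1p(x**2.),
      2. * tf.math.log(tf.math.abs(x)))


def f_double_where(x):
  # New behavior: clamp x to a safe value (1) wherever it is ~0, so every
  # branch has a finite gradient; the outer where still returns |x| there.
  finfo = np.finfo(np.float32)
  is_basically_zero = tf.abs(x) < finfo.tiny
  safe_x = tf.where(is_basically_zero, tf.ones_like(x), x)
  return tf.where(
      is_basically_zero,
      tf.abs(x),
      tf.where(
          tf.abs(x) * np.sqrt(finfo.eps) <= 1.,
          tf.math.log1p(safe_x**2.),
          2. * tf.math.log(tf.math.abs(safe_x))))


x = tf.constant(0.)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f_single_where(x)
print(tape.gradient(y, x))  # expect nan: the bug reported in #1778

with tf.GradientTape() as tape:
  tape.watch(x)
  y = f_double_where(x)
print(tape.gradient(y, x))  # expect 0.0: finite grad after the fix

# The large-x branch also matters for accuracy: in float32, squaring
# overflows long before the +1 becomes negligible to log1p.
big = tf.constant(1e20)
print(tf.math.log1p(big**2.))  # inf: 1e40 overflows float32
print(f_double_where(big))     # ~92.1: 2 * log(1e20)

The new testDoubleWhere checks the same property through the public API: the
gradient of log_prob with respect to loc, evaluated at a point equal to loc,
must contain no nans.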