
Commit f051e03

Merge pull request #933 from emilyfertig/r0.10

R0.10

2 parents: 3244d86 + 73ce8fa
File tree: 5 files changed (+57, −5 lines)

tensorflow_probability/python/bijectors/BUILD

Lines changed: 4 additions & 0 deletions

@@ -1521,12 +1521,16 @@ multi_substrate_py_test(
     name = "softplus_test",
     size = "small",
     srcs = ["softplus_test.py"],
+    jax_size = "medium",
     deps = [
         ":bijector_test_util",
         ":bijectors",
+        # absl/testing:parameterized dep,
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/internal:test_util",
+        "//tensorflow_probability/python/math",
+        # tensorflow/compiler/jit dep,
     ],
 )

tensorflow_probability/python/bijectors/softplus.py

Lines changed: 27 additions & 2 deletions

@@ -33,6 +33,31 @@
 ]


+JAX_MODE = False  # Overwritten by rewrite script.
+
+
+# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
+if JAX_MODE:
+  _stable_grad_softplus = tf.nn.softplus
+else:
+
+  @tf.custom_gradient
+  def _stable_grad_softplus(x):
+    """A (more) numerically stable softplus than `tf.nn.softplus`."""
+    x = tf.convert_to_tensor(x)
+    if x.dtype == tf.float64:
+      cutoff = -20
+    else:
+      cutoff = -9
+
+    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
+
+    def grad_fn(dy):
+      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
+
+    return y, grad_fn
+
+
 class Softplus(bijector.Bijector):
   """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`.

@@ -101,9 +126,9 @@ def _is_increasing(cls):

   def _forward(self, x):
     if self.hinge_softness is None:
-      return tf.math.softplus(x)
+      return _stable_grad_softplus(x)
     hinge_softness = tf.cast(self.hinge_softness, x.dtype)
-    return hinge_softness * tf.math.softplus(x / hinge_softness)
+    return hinge_softness * _stable_grad_softplus(x / hinge_softness)

   def _inverse(self, y):
     if self.hinge_softness is None:
tensorflow_probability/python/bijectors/softplus_test.py

Lines changed: 21 additions & 0 deletions

@@ -20,9 +20,11 @@

 # Dependency imports

+from absl.testing import parameterized
 import numpy as np
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import bijectors as tfb
+from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.bijectors import bijector_test_util
 from tensorflow_probability.python.internal import test_util

@@ -149,6 +151,25 @@ def testVariableHingeSoftness(self):
     with tf.control_dependencies([hinge_softness.assign(0.)]):
       self.evaluate(b.forward(0.5))

+  @parameterized.named_parameters(
+      ('32bitGraph', np.float32, False),
+      ('64bitGraph', np.float64, False),
+      ('32bitXLA', np.float32, True),
+      ('64bitXLA', np.float64, True),
+  )
+  @test_util.numpy_disable_gradient_test
+  def testLeftTailGrad(self, dtype, do_compile):
+    x = np.linspace(-50., -8., 1000).astype(dtype)
+
+    @tf.function(autograph=False, experimental_compile=do_compile)
+    def fn(x):
+      return tf.math.log(tfb.Softplus().forward(x))
+
+    _, grad = tfp_math.value_and_gradient(fn, x)
+
+    true_grad = 1 / (1 + np.exp(-x)) / np.log1p(np.exp(x))
+    self.assertAllClose(true_grad, self.evaluate(grad), atol=1e-3)
+

 if __name__ == '__main__':
   tf.test.main()
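For reference, the `true_grad` expression in the new test is the chain rule applied to `log(softplus(x))`: d/dx log(softplus(x)) = softplus'(x) / softplus(x) = sigmoid(x) / softplus(x), i.e. `1 / (1 + exp(-x)) / log1p(exp(x))`. A quick NumPy check of the identity at one illustrative point:

import numpy as np

x = np.float64(-8.)
sigmoid = 1. / (1. + np.exp(-x))  # softplus'(x)
softplus = np.log1p(np.exp(x))    # softplus(x)
print(sigmoid / softplus)         # ~0.9998; the ratio -> 1 as x -> -inf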

tensorflow_probability/python/distributions/joint_distribution_named.py

Lines changed: 3 additions & 1 deletion

@@ -287,5 +287,7 @@ def _convert_to_dict(x):
   if isinstance(x, collections.OrderedDict):
     return x
   if hasattr(x, '_asdict'):
-    return x._asdict()
+    # Wrap with `OrderedDict` to indicate that namedtuples have a well-defined
+    # order (by default, they convert to just `dict` in Python 3.8+).
+    return collections.OrderedDict(x._asdict())
   return dict(x)
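The comment in the diff refers to a real behavior change: through Python 3.7, `namedtuple._asdict()` returned an `OrderedDict`, but from Python 3.8 it returns a plain `dict`. A small illustration (the `Pair` type is hypothetical):

import collections

Pair = collections.namedtuple('Pair', ['a', 'b'])
p = Pair(a=1, b=2)
print(type(p._asdict()))  # <class 'dict'> on Python 3.8+
print(collections.OrderedDict(p._asdict()))  # OrderedDict([('a', 1), ('b', 2)])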

tensorflow_probability/python/layers/distribution_layer.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@
 import pickle

 # Dependency imports
-from cloudpickle import CloudPickler
+from cloudpickle.cloudpickle import CloudPickler
 import numpy as np
 import six
 import tensorflow.compat.v2 as tf
@@ -47,7 +47,7 @@
 from tensorflow_probability.python.distributions import variational_gaussian_process as variational_gaussian_process_lib
 from tensorflow_probability.python.internal import distribution_util as dist_util
 from tensorflow_probability.python.layers.internal import distribution_tensor_coercible as dtc
-from tensorflow_probability.python.layers.internal import tensor_tuple as tensor_tuple
+from tensorflow_probability.python.layers.internal import tensor_tuple
 from tensorflow.python.keras.utils import tf_utils as keras_tf_utils  # pylint: disable=g-direct-tensorflow-import
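The diff does not state why the import moved from the package top level to the `cloudpickle.cloudpickle` submodule; presumably it is to keep working across cloudpickle versions whose package-level re-exports differ. A quick sanity check that the submodule path resolves and round-trips a closure (a sketch, assuming a cloudpickle release where `CloudPickler` lives in that submodule):

import io
import pickle

from cloudpickle.cloudpickle import CloudPickler

buf = io.BytesIO()
CloudPickler(buf, protocol=pickle.DEFAULT_PROTOCOL).dump(lambda v: v + 1)
restored = pickle.loads(buf.getvalue())
print(restored(41))  # 42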