diff --git a/tensorflow_probability/python/distributions/power_spherical.py b/tensorflow_probability/python/distributions/power_spherical.py index 820566e930..acfeafa13a 100644 --- a/tensorflow_probability/python/distributions/power_spherical.py +++ b/tensorflow_probability/python/distributions/power_spherical.py @@ -217,8 +217,9 @@ def _log_normalization(self, concentration=None, mean_direction=None): concentration1 = concentration + (event_size - 1.) / 2. concentration0 = (event_size - 1.) / 2. - return ((concentration1 + concentration0) * np.log(2.) + - concentration0 * np.log(np.pi) + + np_dtype = dtype_util.as_numpy_dtype(concentration.dtype) + return ((concentration1 + concentration0) * np.log(2.).astype(np_dtype) + + concentration0 * np.log(np.pi).astype(np_dtype) + special.log_gamma_difference(concentration0, concentration1)) def _sample_control_dependencies(self, samples):