diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp.py b/tensorflow_probability/python/experimental/fastgp/fast_gp.py index b7c299af2a..f382271f6c 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_gp.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_gp.py @@ -46,15 +46,15 @@ class GaussianProcessConfig: # The maximum number of iterations to run conjugate gradients # for when calculating the yt_inv_y part of the log prob. - cg_iters: int = 20 + cg_iters: int = 25 # The name of a preconditioner in the preconditioner.PRECONDITIONER_REGISTRY - # or 'auto' which will used truncated_svd when n is large and - # partial_cholesky_split when is small. + # or 'auto' which will use truncated_randomized_svd_plus_scaling when n is + # large and partial_cholesky_split when n is small. preconditioner: str = 'auto' # Use a preconditioner based on a low rank # approximation of this rank. Note that not all preconditioners have # adjustable ranks. - preconditioner_rank: int = 20 + preconditioner_rank: int = 25 # Some preconditioners (like `truncated_svd`) can # get better approximation accuracy for running for more iterations (even # at a fixed rank size). This parameter controls that. Note that the @@ -69,9 +69,9 @@ class GaussianProcessConfig: precondition_before_jitter: str = 'auto' # Either `normal`, `normal_qmc`, `normal_orthogonal` or # `rademacher`. `normal_qmc` is only valid for n <= 10000 - probe_vector_type: str = 'normal_orthogonal' + probe_vector_type: str = 'rademacher' # The number of probe vectors to use when estimating the log det. - num_probe_vectors: int = 30 + num_probe_vectors: int = 35 # One of 'slq' (for stochastic Lancos quadrature) or # 'r1', 'r2', 'r3', 'r4', 'r5', or 'r6' for the rational function # approximation of the given order. 
@@ -258,7 +258,9 @@ def _log_det(self, key, is_missing=None): mask_loc=False, ) - is_scaling_preconditioner = self._config.preconditioner.endswith('scaling') + is_scaling_preconditioner = preconditioners.resolve_preconditioner( + self._config.preconditioner, covariance, + self._config.preconditioner_rank).endswith('scaling') precondition_before_jitter = ( self._config.precondition_before_jitter == 'true' or ( diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py b/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py index c017580570..18ff82fbea 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py @@ -450,7 +450,7 @@ def slow_log_prob(amplitude, length_scale, noise, jitter): self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3)) direct_slow_value = slow_log_prob( self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3)) - self.assertAlmostEqual(direct_value, direct_slow_value, delta=3e-4) + self.assertAlmostEqual(direct_value, direct_slow_value, delta=4e-4) slow_value, slow_gradient = jax.value_and_grad( slow_log_prob, argnums=[0, 1, 2, 3] diff --git a/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py b/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py index 98e0b0c106..e6b3a5e5fe 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py @@ -186,7 +186,7 @@ def test_gaussian_process_log_prob2(self): fast_ll = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1))) slow_ll = jnp.sum(sgp.log_prob(samples)) - np.testing.assert_allclose(fast_ll, slow_ll, rtol=2e-4) + np.testing.assert_allclose(fast_ll, slow_ll, rtol=4e-4) def test_gaussian_process_log_prob_jits(self): kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( @@ -345,12 +345,12 @@ def slow_log_prob(amplitude, 
length_scale, noise): self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) direct_slow_value = slow_log_prob( self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) - np.testing.assert_allclose(direct_value, direct_slow_value, rtol=4e-4) + np.testing.assert_allclose(direct_value, direct_slow_value, rtol=1e-3) slow_value, slow_gradient = jax.value_and_grad( slow_log_prob, argnums=[0, 1, 2] )(self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) - np.testing.assert_allclose(value, slow_value, rtol=4e-4) + np.testing.assert_allclose(value, slow_value, rtol=1e-3) slow_d_amp, slow_d_length_scale, slow_d_noise = slow_gradient np.testing.assert_allclose(d_amp, slow_d_amp, rtol=1e-4) np.testing.assert_allclose(d_length_scale, slow_d_length_scale, rtol=1e-4) diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners.py b/tensorflow_probability/python/experimental/fastgp/preconditioners.py index 71ad8d0806..be80af4729 100644 --- a/tensorflow_probability/python/experimental/fastgp/preconditioners.py +++ b/tensorflow_probability/python/experimental/fastgp/preconditioners.py @@ -721,18 +721,27 @@ def __init__( } +def resolve_preconditioner( + preconditioner_name: str, + M: tf2jax.linalg.LinearOperator, + rank: int) -> str: + """Return the resolved preconditioner_name.""" + if preconditioner_name == 'auto': + n = M.shape[-1] + if 5 * rank >= n: + return 'partial_cholesky_split' + else: + return 'truncated_randomized_svd_plus_scaling' + return preconditioner_name + + @jax.named_call def get_preconditioner( preconditioner_name: str, M: tf2jax.linalg.LinearOperator, **kwargs ) -> SplitPreconditioner: """Return the preconditioner of the given type for the given matrix.""" - if preconditioner_name == 'auto': - n = M.shape[-1] - rank = kwargs.get('rank', 20) - if 5 * rank >= n: - preconditioner_name = 'partial_cholesky_split' - else: - preconditioner_name = 'truncated_svd' + preconditioner_name = resolve_preconditioner( + preconditioner_name, M, kwargs.get('rank', 20)) 
try: return PRECONDITIONER_REGISTRY[preconditioner_name](M, **kwargs) except KeyError as key_error: