Skip to content

Commit

Permalink
Update default parameter values.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606667138
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Feb 13, 2024
1 parent d3116d7 commit be4732f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
16 changes: 9 additions & 7 deletions tensorflow_probability/python/experimental/fastgp/fast_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ class GaussianProcessConfig:

# The maximum number of iterations to run conjugate gradients
# for when calculating the yt_inv_y part of the log prob.
cg_iters: int = 20
cg_iters: int = 25
# The name of a preconditioner in the preconditioner.PRECONDITIONER_REGISTRY
# or 'auto' which will used truncated_svd when n is large and
# partial_cholesky_split when is small.
# or 'auto' which will use truncated_randomized_svd_plus_scaling when n is
# large and partial_cholesky_split when n is small.
preconditioner: str = 'auto'
# Use a preconditioner based on a low rank
# approximation of this rank. Note that not all preconditioners have
# adjustable ranks.
preconditioner_rank: int = 20
preconditioner_rank: int = 25
# Some preconditioners (like `truncated_svd`) can
# get better approximation accuracy for running for more iterations (even
# at a fixed rank size). This parameter controls that. Note that the
Expand All @@ -69,9 +69,9 @@ class GaussianProcessConfig:
precondition_before_jitter: str = 'auto'
# Either `normal`, `normal_qmc`, `normal_orthogonal` or
# `rademacher`. `normal_qmc` is only valid for n <= 10000
probe_vector_type: str = 'normal_orthogonal'
probe_vector_type: str = 'rademacher'
# The number of probe vectors to use when estimating the log det.
num_probe_vectors: int = 30
num_probe_vectors: int = 35
# One of 'slq' (for stochastic Lancos quadrature) or
# 'r1', 'r2', 'r3', 'r4', 'r5', or 'r6' for the rational function
# approximation of the given order.
Expand Down Expand Up @@ -258,7 +258,9 @@ def _log_det(self, key, is_missing=None):
mask_loc=False,
)

is_scaling_preconditioner = self._config.preconditioner.endswith('scaling')
is_scaling_preconditioner = preconditioners.resolve_preconditioner(
self._config.preconditioner, covariance,
self._config.preconditioner_rank).endswith('scaling')
precondition_before_jitter = (
self._config.precondition_before_jitter == 'true'
or (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def slow_log_prob(amplitude, length_scale, noise, jitter):
self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3))
direct_slow_value = slow_log_prob(
self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3))
self.assertAlmostEqual(direct_value, direct_slow_value, delta=3e-4)
self.assertAlmostEqual(direct_value, direct_slow_value, delta=4e-4)

slow_value, slow_gradient = jax.value_and_grad(
slow_log_prob, argnums=[0, 1, 2, 3]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_gaussian_process_log_prob2(self):

fast_ll = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1)))
slow_ll = jnp.sum(sgp.log_prob(samples))
np.testing.assert_allclose(fast_ll, slow_ll, rtol=2e-4)
np.testing.assert_allclose(fast_ll, slow_ll, rtol=4e-4)

def test_gaussian_process_log_prob_jits(self):
kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
Expand Down Expand Up @@ -345,12 +345,12 @@ def slow_log_prob(amplitude, length_scale, noise):
self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3))
direct_slow_value = slow_log_prob(
self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3))
np.testing.assert_allclose(direct_value, direct_slow_value, rtol=4e-4)
np.testing.assert_allclose(direct_value, direct_slow_value, rtol=1e-3)

slow_value, slow_gradient = jax.value_and_grad(
slow_log_prob, argnums=[0, 1, 2]
)(self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3))
np.testing.assert_allclose(value, slow_value, rtol=4e-4)
np.testing.assert_allclose(value, slow_value, rtol=1e-3)
slow_d_amp, slow_d_length_scale, slow_d_noise = slow_gradient
np.testing.assert_allclose(d_amp, slow_d_amp, rtol=1e-4)
np.testing.assert_allclose(d_length_scale, slow_d_length_scale, rtol=1e-4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,18 +721,27 @@ def __init__(
}


def resolve_preconditioner(
    preconditioner_name: str,
    M: tf2jax.linalg.LinearOperator,
    rank: int) -> str:
  """Map a possibly-'auto' preconditioner name to a concrete one.

  Any name other than 'auto' is passed through unchanged. For 'auto', the
  choice depends on how the requested approximation rank compares to the
  matrix dimension: small problems get a partial-Cholesky split
  preconditioner, large ones a randomized-SVD-plus-scaling preconditioner.
  """
  if preconditioner_name != 'auto':
    return preconditioner_name
  # M.shape[-1] is the (square) operator's dimension n.
  matrix_dim = M.shape[-1]
  rank_covers_matrix = 5 * rank >= matrix_dim
  return (
      'partial_cholesky_split'
      if rank_covers_matrix
      else 'truncated_randomized_svd_plus_scaling')


@jax.named_call
def get_preconditioner(
preconditioner_name: str, M: tf2jax.linalg.LinearOperator, **kwargs
) -> SplitPreconditioner:
"""Return the preconditioner of the given type for the given matrix."""
if preconditioner_name == 'auto':
n = M.shape[-1]
rank = kwargs.get('rank', 20)
if 5 * rank >= n:
preconditioner_name = 'partial_cholesky_split'
else:
preconditioner_name = 'truncated_svd'
preconditioner_name = resolve_preconditioner(
preconditioner_name, M, kwargs.get('rank', 20))
try:
return PRECONDITIONER_REGISTRY[preconditioner_name](M, **kwargs)
except KeyError as key_error:
Expand Down

0 comments on commit be4732f

Please sign in to comment.