diff --git a/tensorflow_probability/python/experimental/fastgp/BUILD b/tensorflow_probability/python/experimental/fastgp/BUILD new file mode 100644 index 0000000000..f956e86750 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/BUILD @@ -0,0 +1,259 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Algorithms for training Gaussian Processes in time less than O(n^3) + +# Placeholder: py_library +# Placeholder: py_test + +licenses(["notice"]) + +package( + # default_applicable_licenses + default_visibility = [ + "//research/probability:__subpackages__", + "//tensorflow_probability:__subpackages__", + # vizier:__subpackages__ dep, + ], +) + +py_library( + name = "mbcg", + srcs = ["mbcg.py"], + deps = [ + # jax dep, + ], +) + +py_test( + name = "mbcg_test", + srcs = ["mbcg_test.py"], + deps = [ + ":mbcg", + # absl/testing:absltest dep, + # jax dep, + # numpy dep, + ], +) + +py_library( + name = "fast_gp", + srcs = ["fast_gp.py"], + deps = [ + ":fast_log_det", + ":mbcg", + ":preconditioners", + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "fast_gp_test", + srcs = ["fast_gp_test.py"], + shard_count = 3, + deps = [ + ":fast_gp", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "fast_mtgp", + srcs = ["fast_mtgp.py"], + deps = [ + ":fast_gp", + ":fast_log_det", + ":linear_operator_sum", + ":mbcg", + ":preconditioners", + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "fast_mtgp_test", + srcs = ["fast_mtgp_test.py"], + shard_count = 3, + deps = [ + ":fast_mtgp", + # absl/testing:absltest dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "fast_gprm", + srcs = ["fast_gprm.py"], + deps = [ + ":fast_gp", + ":preconditioners", + ":schur_complement", + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "fast_gprm_test", + srcs = ["fast_gprm_test.py"], + deps = [ + ":fast_gprm", + # absl/testing:absltest dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/python/internal:test_util.jax", + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "linalg", + srcs = ["linalg.py"], + deps = [ + ":partial_lanczos", + # jax dep, + # jax:experimental_sparse dep, + # jaxtyping dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "linalg_test", + srcs = ["linalg_test.py"], + shard_count = 3, + deps = [ + ":linalg", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "partial_lanczos", + srcs = ["partial_lanczos.py"], + deps = [ + ":mbcg", + # jax dep, + # scipy dep, + "//tensorflow_probability/substrates:jax", + ], +) + 
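For orientation, here is a minimal end-to-end usage sketch of the package these targets build (an editor's illustration, not part of the patch; it relies only on the `fast_gp` API defined later in this diff, and the kernel parameters, problem size, and seeds are arbitrary):

import jax
import jax.numpy as jnp
from tensorflow_probability.python.experimental.fastgp import fast_gp
from tensorflow_probability.substrates import jax as tfp

jax.config.update('jax_enable_x64', True)  # float64 mode is recommended.

kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
    jnp.float64(1.0), jnp.float64(0.5), feature_ndims=0)
index_points = jax.random.uniform(
    jax.random.PRNGKey(0), shape=(500,), dtype=jnp.float64)
gp = fast_gp.GaussianProcess(
    kernel,
    index_points,
    observation_noise_variance=jnp.float64(1e-3),
    config=fast_gp.GaussianProcessConfig(num_probe_vectors=30))
# log_prob takes an explicit key because the log-det estimate is randomized.
lp = gp.log_prob(jnp.zeros((500,), dtype=jnp.float64), key=jax.random.PRNGKey(1))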
+py_test( + name = "partial_lanczos_test", + srcs = ["partial_lanczos_test.py"], + deps = [ + ":partial_lanczos", + # absl/testing:absltest dep, + # jax dep, + # numpy dep, + ], +) + +py_library( + name = "fast_log_det", + srcs = ["fast_log_det.py"], + deps = [ + ":mbcg", + ":partial_lanczos", + ":preconditioners", + # jax dep, + # jaxtyping dep, + # numpy dep, + # scipy dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "fast_log_det_test", + srcs = ["fast_log_det_test.py"], + shard_count = 3, + deps = [ + ":fast_log_det", + ":preconditioners", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "linear_operator_sum", + srcs = ["linear_operator_sum.py"], + deps = [ + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "preconditioners", + srcs = ["preconditioners.py"], + deps = [ + ":linalg", + ":linear_operator_sum", + # jax dep, + # jax:experimental_sparse dep, + # jaxtyping dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "preconditioners_test", + srcs = ["preconditioners_test.py"], + shard_count = 3, + deps = [ + ":preconditioners", + # absl/testing:absltest dep, + # absl/testing:parameterized dep, + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_library( + name = "schur_complement", + srcs = ["schur_complement.py"], + deps = [ + ":preconditioners", + # jax dep, + "//tensorflow_probability/substrates:jax", + ], +) + +py_test( + name = "schur_complement_test", + srcs = ["schur_complement_test.py"], + deps = [ + ":schur_complement", + # absl/testing:absltest dep, + # jax dep, + # numpy dep, + "//tensorflow_probability/substrates:jax", + ], +) diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp.py b/tensorflow_probability/python/experimental/fastgp/fast_gp.py new file mode 100644 index 0000000000..79ea875948 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_gp.py @@ -0,0 +1,418 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Fast likelihoods etc. for Gaussian Processes. + +It's recommended to use `GaussianProcess` in `float64` mode only. 
+""" + +import dataclasses +import functools + +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.fastgp import fast_log_det +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.substrates import jax as tfp + +tfb = tfp.bijectors +tfd = tfp.distributions +jtf = tfp.tf2jax +Array = jnp.ndarray + +LOG_TWO_PI = 1.8378770664093453 + + +@dataclasses.dataclass +class GaussianProcessConfig: + """Configuration for distributions in the FastGP family.""" + + # The maximum number of iterations to run conjugate gradients + # for when calculating the yt_inv_y part of the log prob. + cg_iters: int = 20 + # The name of a preconditioner in the preconditioner.PRECONDITIONER_REGISTRY + # or 'auto' which will used truncated_svd when n is large and + # partial_cholesky_split when is small. + preconditioner: str = 'auto' + # Use a preconditioner based on a low rank + # approximation of this rank. Note that not all preconditioners have + # adjustable ranks. + preconditioner_rank: int = 20 + # Some preconditioners (like `truncated_svd`) can + # get better approximation accuracy for running for more iterations (even + # at a fixed rank size). This parameter controls that. Note that the + # cost of computing the preconditioner can be as much as + # O(preconditioner_rank * preconditioner_num_iters * n^2) + preconditioner_num_iters: int = 5 + # If "true", compute the preconditioner before any diagonal terms are added + # to the covariance. If "false", compute the preconditioner on the sum of + # the original covariance plus the diagonal terms. If "auto", compute the + # preconditioner before the diagonal terms for scaling preconditioners, + # and after the diagonal terms for all other preconditioners. + precondition_before_jitter: str = 'auto' + # Either `normal`, `normal_qmc`, `normal_orthogonal` or + # `rademacher`. `normal_qmc` is only valid for n <= 10000 + probe_vector_type: str = 'normal_orthogonal' + # The number of probe vectors to use when estimating the log det. + num_probe_vectors: int = 30 + # One of 'slq' (for stochastic Lancos quadrature) or + # 'r1', 'r2', 'r3', 'r4', 'r5', or 'r6' for the rational function + # approximation of the given order. + log_det_algorithm: str = 'r3' + # The number of iterations to use when doing solves inside the log det + # algorithm. + log_det_iters: int = 20 + + +class GaussianProcess(tfp.distributions.Distribution): + """Fast, JAX-only implementation of a GP distribution class. + + See tfd.distributions.GaussianProcess for a description and parameter + documentation. Currently only supports log_prob and posterior_predictive + (the only two methods used by smc.py). + + The default parameters are tuned to give a good time / error trade-off + in the n > 15,000 regime where this class gives a substantial speed-up + over tfd.distributions.GaussianProcess. In particular, it is tuned to + give a trade-off in the case where you care about the accuracy of both + log_prob and its derivative. If you care only about log_prob, it is + recommended to use log_det_algorithm='slq' with preconditioner_num_iters=1. + """ + + def __init__( + self, + kernel, + index_points=None, + mean_fn=None, + observation_noise_variance=0.0, + jitter=1e-6, + config=GaussianProcessConfig(), + validate_args=False, + allow_nan_stats=False, + ): + """Instantiate a fast GaussianProcess distribution. 
+ + Args: + kernel: A `PositiveSemidefiniteKernel`-like instance representing the GP's + covariance function. + index_points: Tensor specifying the points over which the GP is defined. + mean_fn: Python callable that acts on index_points. Default `None` + implies a constant zero mean function. + observation_noise_variance: `float` `Tensor` representing the scalar + variance of the noise in the Normal likelihood distribution of the + model. + jitter: `float` scalar `Tensor` that gets added to the diagonal of the + GP's covariance matrix to ensure it is positive definite. + config: `GaussianProcessConfig` to control speed and quality of GP + approximations. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. Default value: `False`. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or more + of the statistic's batch members are undefined. Default value: `False`. + """ + parameters = dict(locals()) + if jax.tree_util.treedef_is_leaf( + jax.tree_util.tree_structure(kernel.feature_ndims)): + # If the index points are not nested, we assume they are of the same + # float dtype as the GP. + dtype = tfp.internal.dtype_util.common_dtype( + {'index_points': index_points, + 'observation_noise_variance': observation_noise_variance, + 'jitter': jitter}, + jnp.float32) + else: + dtype = tfp.internal.dtype_util.common_dtype( + {'observation_noise_variance': observation_noise_variance, + 'jitter': jitter}, + jnp.float32) + + self._kernel = kernel + self._index_points = index_points + self._mean_fn = tfd.internal.stochastic_process_util.maybe_create_mean_fn( + mean_fn, dtype + ) + self._observation_noise_variance = ( + tfp.internal.tensor_util.convert_nonref_to_tensor( + observation_noise_variance + ) + ) + self._jitter = jitter + self._config = config + self._probe_vector_type = fast_log_det.ProbeVectorType[ + config.probe_vector_type.upper()] + self._log_det_fn = fast_log_det.get_log_det_algorithm( + config.log_det_algorithm + ) + + super(GaussianProcess, self).__init__( + dtype=dtype, + reparameterization_type=tfd.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + name='GaussianProcess') + + @property + def kernel(self): + return self._kernel + + @property + def index_points(self): + return self._index_points + + @property + def mean_fn(self): + return self._mean_fn + + @property + def observation_noise_variance(self): + return self._observation_noise_variance + + @property + def jitter(self): + return self._jitter + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + return dict( + index_points=tfp.util.ParameterProperties( + event_ndims=lambda self: jax.tree_util.tree_map( # pylint: disable=g-long-lambda + lambda nd: nd + 1, self.kernel.feature_ndims), + shape_fn=tfp.internal.parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + ), + kernel=tfp.util.BatchedComponentProperties(), + observation_noise_variance=( + tfp.util.ParameterProperties( + event_ndims=0, + shape_fn=lambda sample_shape: sample_shape[:-1], + default_constraining_bijector_fn=( + lambda: tfb.Softplus( # pylint: disable=g-long-lambda + low=tfp.internal.dtype_util.eps(dtype)))))) + + @property + def event_shape(self): + 
return tfd.internal.stochastic_process_util.event_shape( + self._kernel, self.index_points) + + def _mean(self): + mean = self._mean_fn(self._index_points) + mean = jnp.broadcast_to(mean, self.event_shape) + return mean + + def _covariance(self): + index_points = self._index_points + _, covariance = ( + tfd.internal.stochastic_process_util.get_loc_and_kernel_matrix( + kernel=self.kernel, + mean_fn=self._mean_fn, + observation_noise_variance=self.observation_noise_variance, + index_points=index_points)) + return covariance + + def _variance(self): + index_points = self._index_points + kernel_diag = self.kernel.apply(index_points, index_points, example_ndims=1) + return kernel_diag + self.observation_noise_variance[..., jnp.newaxis] + + def _log_det(self, key, is_missing=None): + """Returns log_det, loc, covariance and preconditioner.""" + key1, key2 = jax.random.split(key) + + # TODO(thomaswc): Considering caching loc and covariance for a given + # is_missing. + loc, covariance = ( + tfd.internal.stochastic_process_util.get_loc_and_kernel_matrix( + kernel=self._kernel, + mean_fn=self._mean_fn, + observation_noise_variance=self.dtype(0.), + index_points=self._index_points, + is_missing=is_missing, + mask_loc=False, + ) + ) + + is_scaling_preconditioner = self._config.preconditioner.endswith('scaling') + precondition_before_jitter = ( + self._config.precondition_before_jitter == 'true' + or ( + self._config.precondition_before_jitter == 'auto' + and is_scaling_preconditioner + ) + ) + + def get_preconditioner(cov): + scaling = None + if is_scaling_preconditioner: + scaling = self._observation_noise_variance + self._jitter + return preconditioners.get_preconditioner( + self._config.preconditioner, + cov, + key=key1, + rank=self._config.preconditioner_rank, + num_iters=self._config.preconditioner_num_iters, + scaling=scaling) + + if precondition_before_jitter: + preconditioner = get_preconditioner(covariance) + + updated_diagonal = ( + jnp.diag(covariance) + self._jitter + self.observation_noise_variance) + if is_missing is not None: + updated_diagonal = jnp.where( + is_missing, self.dtype(1. + self._jitter), updated_diagonal) + + covariance = (covariance * ( + 1 - jnp.eye( + updated_diagonal.shape[-1], dtype=updated_diagonal.dtype) + ) + jnp.diag(updated_diagonal)) + + if not precondition_before_jitter: + preconditioner = get_preconditioner(covariance) + + det_term = self._log_det_fn( + covariance, + preconditioner, + key=key2, + num_probe_vectors=self._config.num_probe_vectors, + probe_vector_type=self._probe_vector_type, + num_iters=self._config.log_det_iters, + ) + + return det_term, loc, covariance, preconditioner + + @jax.named_call + def log_prob(self, value, key, is_missing=None) -> Array: + """log P(value | GP).""" + empty_sample_batch_shape = value.ndim == 1 + if empty_sample_batch_shape: + value = value[jnp.newaxis] + if value.ndim != 2: + raise ValueError( + 'fast_gp.GaussianProcess.log_prob only supports values of rank 1 or ' + f'2, got rank {value.ndim} instead.' 
+ ) + + num_unmasked_dims = value.shape[-1] + + det_term, loc, covariance, preconditioner = self._log_det(key, is_missing) + + centered_value = value - loc + if is_missing is not None: + centered_value = jnp.where(is_missing, 0.0, centered_value) + num_unmasked_dims = num_unmasked_dims - jnp.count_nonzero( + is_missing, axis=-1 + ) + + exp_term = yt_inv_y( + covariance, + preconditioner.full_preconditioner(), + jnp.transpose(centered_value), + max_iters=self._config.cg_iters, + ) + + lp = -0.5 * (LOG_TWO_PI * num_unmasked_dims + det_term + exp_term) + if empty_sample_batch_shape: + return jnp.squeeze(lp, axis=0) + + return lp + + def posterior_predictive( + self, + observations, + predictive_index_points=None, + observations_is_missing=None, + **kwargs + ): + # TODO(thomaswc): Speed this up, if possible. + return tfd.GaussianProcessRegressionModel.precompute_regression_model( + kernel=self._kernel, + observation_index_points=self._index_points, + observations=observations, + observations_is_missing=observations_is_missing, + index_points=predictive_index_points, + observation_noise_variance=self._observation_noise_variance, + mean_fn=self._mean_fn, + cholesky_fn=None, + jitter=self._jitter, + **kwargs + ) + + +# pylint: disable=invalid-name + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(3,)) +def yt_inv_y( + kernel: jtf.linalg.LinearOperator, + preconditioner: jtf.linalg.LinearOperator, + y: Array, + max_iters: int = 20, +) -> Array: + """Compute y^t (kernel)^(-1) y. + + Args: + kernel: A matrix or jtf.LinearOperator representing a linear map from R^n to + itself. + preconditioner: An operator on R^n that when applied before kernel, + reduces the condition number of the system. + y: A matrix of shape (n, m). + max_iters: The maximum number of iterations to perform the modified batched + conjugate gradients algorithm for. + + Returns: + y's inner product with itself, with respect to the inverse of the kernel. + """ + def multiplier(B): + return kernel @ B + + inv_y, _ = mbcg.modified_batched_conjugate_gradients( + multiplier, y, preconditioner.solve, max_iters + ) + return jnp.einsum('ij,ij->j', y, inv_y) + + +@yt_inv_y.defjvp +def yt_inv_y_jvp(max_iters, primals, tangents): + """Jacobian-Vector product for yt_inv_y.""" + # According to 2.3.3 of + # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf, + # d(y^t A^(-1) y) = dy^t A^(-1) y - y^t A^(-1) dA A^(-1) y + y^t A^(-1) dy + kernel = primals[0] + preconditioner = primals[1] + y = primals[2] + + dkernel = tangents[0] + dy = tangents[2] + + def multiplier(B): + return kernel @ B + + inv_y, _ = mbcg.modified_batched_conjugate_gradients( + multiplier, y, preconditioner.solve, max_iters + ) + + primal_out = jnp.einsum('ij,ij->j', y, inv_y) + tangent_out = 2.0 * jnp.einsum('ik,ik->k', inv_y, dy) + + if isinstance(dkernel, jtf.linalg.LinearOperator): + tangent_out = tangent_out - jnp.einsum('ik,ik->k', inv_y, dkernel @ inv_y) + else: + tangent_out = tangent_out - jnp.einsum('ik,ij,jk->k', inv_y, dkernel, inv_y) + + return primal_out, tangent_out diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py b/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py new file mode 100644 index 0000000000..c017580570 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_gp_test.py @@ -0,0 +1,824 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for fast_gp.py.""" + +from absl.testing import parameterized +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import fast_gp +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.substrates import jax as tfp + +from absl.testing import absltest + +tfd = tfp.distributions + + +class _FastGpTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + np.random.seed(5) + self.points = np.random.rand(100, 30).astype(self.dtype) + + def test_gaussian_process_copy(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10,), minval=-1.0, maxval=1.0, + dtype=self.dtype) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + my_gp_copy = my_gp.copy( + config=fast_gp.GaussianProcessConfig(preconditioner_rank=50) + ) + my_gp_params = my_gp.parameters.copy() + my_gp_copy_params = my_gp_copy.parameters.copy() + self.assertNotEqual( + my_gp_params.pop("config"), my_gp_copy_params.pop("config") + ) + self.assertEqual(my_gp.batch_shape, []) + self.assertEqual(my_gp_params, my_gp_copy_params) + + def test_gaussian_process_mean(self): + mean_fn = lambda x: x[:, 0]**2 + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic() + index_points = np.expand_dims( + np.random.uniform(-1., 1., 10).astype(self.dtype), -1) + gp = fast_gp.GaussianProcess(kernel, index_points, mean_fn=mean_fn) + expected_mean = mean_fn(index_points) + np.testing.assert_allclose( + expected_mean, gp.mean(), rtol=1e-5) + + def test_gaussian_process_covariance_and_variance(self): + amp = self.dtype(.5) + len_scale = self.dtype(.2) + jitter = self.dtype(1e-4) + observation_noise_variance = self.dtype(3e-3) + + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amp, len_scale) + + index_points = np.expand_dims( + np.random.uniform(-1., 1., 10).astype(self.dtype), -1) + + gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=observation_noise_variance, + jitter=jitter) + + def _kernel_fn(x, y): + return amp ** 2 * np.exp(-.5 * (np.squeeze((x - y)**2)) / (len_scale**2)) + + expected_covariance = ( + _kernel_fn(np.expand_dims(index_points, 0), + np.expand_dims(index_points, 1)) + + observation_noise_variance * np.eye(10)) + + np.testing.assert_allclose( + expected_covariance, gp.covariance(), rtol=1e-5) + np.testing.assert_allclose( + np.diag(expected_covariance), gp.variance(), rtol=1e-5) + + def test_gaussian_process_log_prob(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(10,), + minval=-1.0, + maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + 
observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + np.testing.assert_allclose( + my_gp.log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=1e-5, + ) + + @parameterized.parameters( + (fast_gp.GaussianProcessConfig(cg_iters=10), 1.0), + (fast_gp.GaussianProcessConfig(preconditioner="identity"), 1.0), + (fast_gp.GaussianProcessConfig(preconditioner_rank=10), 1.0), + (fast_gp.GaussianProcessConfig(preconditioner_num_iters=10), 1.0), + (fast_gp.GaussianProcessConfig(precondition_before_jitter="true"), 10.0), + (fast_gp.GaussianProcessConfig(precondition_before_jitter="false"), 1.0), + (fast_gp.GaussianProcessConfig(probe_vector_type="rademacher"), 1.0), + (fast_gp.GaussianProcessConfig(num_probe_vectors=20), 1.0), + (fast_gp.GaussianProcessConfig(log_det_algorithm="slq"), 1.0), + (fast_gp.GaussianProcessConfig(log_det_iters=10), 1.0), + ) + def test_gaussian_process_log_prob_with_configs(self, gp_config, delta): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(3,), + minval=-1.0, + maxval=1.0, + dtype=self.dtype, + ) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4), + config=gp_config, + ) + samples = jnp.array([[self.dtype(1.0), self.dtype(2.0), self.dtype(3.0)]]) + lp = my_gp.log_prob(samples, key=jax.random.PRNGKey(1)) + target = -173.0 if self.dtype == np.float32 else -294.0 + self.assertAlmostEqual(lp, target, delta=delta) + + def test_gaussian_process_log_prob_plus_scaling(self): + # Disabled because of b/323368033 + return # EnableOnExport + if self.dtype == np.float32: + self.skipTest("Numerically unstable in Float32.") + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(10,), + minval=-1.0, + maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling") + ) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(2)) + np.testing.assert_allclose( + my_gp.log_prob(samples, key=jax.random.PRNGKey(3)), + slow_gp.log_prob(samples), + rtol=1e-5, + ) + + def test_gaussian_process_log_prob_single_sample(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(10,), + minval=-1.0, + maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + single_sample = slow_gp.sample(seed=jax.random.PRNGKey(0)) + lp = my_gp.log_prob(single_sample, key=jax.random.PRNGKey(1)) + self.assertEqual(single_sample.ndim, 1) 
+ self.assertEmpty(lp.shape) + np.testing.assert_allclose( + lp, + slow_gp.log_prob(single_sample), + rtol=1e-5, + ) + + def test_gaussian_process_log_prob2(self): + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10,), minval=-1.0, maxval=1.0 + ).astype(self.dtype) + samples = jnp.array([[ + -0.0980842, + -0.27192444, + -0.22313793, + -0.07691351, + -0.1314459, + -0.2322599, + -0.1493263, + -0.11629149, + -0.34304297, + -0.24659207, + ]]).astype(self.dtype) + + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.), self.dtype(1.), feature_ndims=0) + fgp = fast_gp.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + jitter=self.dtype(1e-3), + config=fast_gp.GaussianProcessConfig(preconditioner_rank=10), + ) + sgp = tfd.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + jitter=self.dtype(1e-3) + ) + + fast_ll = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1))) + slow_ll = jnp.sum(sgp.log_prob(samples)) + + self.assertAlmostEqual(fast_ll, slow_ll, delta=3e-4) + + def test_gaussian_process_log_prob_jits(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10,), minval=-1.0, maxval=1.0 + ).astype(self.dtype) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + my_gp_log_prob = jax.jit(my_gp.log_prob) + np.testing.assert_allclose( + my_gp_log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=1e-5, + ) + + def test_gaussian_process_slq_log_prob_jits(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), + shape=(10,), + minval=-1.0, + maxval=1.0).astype(self.dtype) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4), + config=fast_gp.GaussianProcessConfig(log_det_algorithm="slq"), + ) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + my_gp_log_prob = jax.jit(my_gp.log_prob) + np.testing.assert_allclose( + my_gp_log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=1e-5, + ) + + def test_gaussian_process_log_prob_with_is_missing(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.1), self.dtype(0.9)) + index_points = jnp.array( + [[-1.0, 0.0], [-0.5, -0.5], [1.5, 0.0], [1.6, 1.5]], + dtype=self.dtype) + my_gp = fast_gp.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + x = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + is_missing = np.array([True, False, False, True]) + np.testing.assert_allclose( + my_gp.log_prob(x, key=jax.random.PRNGKey(1), is_missing=is_missing), + slow_gp.log_prob(x, is_missing=is_missing), + rtol=1e-6, + ) + + def 
test_gp_log_prob_hard(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10,), minval=-1.0, maxval=1.0 + ).astype(self.dtype) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude=self.dtype(1.), + length_scale=self.dtype(1.), + feature_ndims=0 + ) + fgp = fast_gp.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + jitter=self.dtype(1e-3)) + sgp = tfd.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + jitter=self.dtype(1e-3)) + fgp_lp = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1))) + sgp_lp = jnp.sum(sgp.log_prob(samples)) + + self.assertAlmostEqual(fgp_lp, sgp_lp, places=3) + + def test_gp_log_prob_matern_five_halves(self): + kernel = tfp.math.psd_kernels.MaternFiveHalves( + self.dtype(2.0), self.dtype(1.0)) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 5), minval=-1.0, maxval=1.0 + ).astype(self.dtype) + sgp = tfd.GaussianProcess( + kernel, index_points, observation_noise_variance=self.dtype(0.1) + ) + sample = sgp.sample(1, seed=jax.random.PRNGKey(0)) + fgp = fast_gp.GaussianProcess( + kernel, index_points, observation_noise_variance=self.dtype(0.1) + ) + fgp_lp = jnp.sum(fgp.log_prob(sample, key=jax.random.PRNGKey(1))) + sgp_lp = jnp.sum(sgp.log_prob(sample)) + + self.assertAlmostEqual(fgp_lp, sgp_lp, places=3) + + def test_gaussian_process_log_prob_gradient(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0), feature_ndims=0 + ) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10,), minval=-1.0, maxval=1.0, + dtype=self.dtype) + slow_gp = tfd.GaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4) + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + + def log_prob(amplitude, length_scale, noise, jitter): + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude, length_scale, feature_ndims=0 + ) + gp = fast_gp.GaussianProcess( + k, index_points, observation_noise_variance=noise, jitter=jitter + ) + return jnp.sum(gp.log_prob(samples, key=jax.random.PRNGKey(1))) + + value, gradient = jax.value_and_grad(log_prob, argnums=[0, 1, 2, 3])( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3) + ) + d_amp, d_length_scale, d_noise, d_jitter = gradient + self.assertFalse(jnp.isnan(value)) + self.assertFalse(jnp.isnan(d_amp)) + self.assertFalse(jnp.isnan(d_length_scale)) + self.assertFalse(jnp.isnan(d_noise)) + self.assertFalse(jnp.isnan(d_jitter)) + + def slow_log_prob(amplitude, length_scale, noise, jitter): + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude, length_scale, feature_ndims=0 + ) + gp = tfd.GaussianProcess( + k, index_points, observation_noise_variance=noise, jitter=jitter + ) + return jnp.sum(gp.log_prob(samples)) + + direct_value = log_prob( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3)) + direct_slow_value = slow_log_prob( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3)) + self.assertAlmostEqual(direct_value, direct_slow_value, delta=3e-4) + + slow_value, slow_gradient = jax.value_and_grad( + slow_log_prob, argnums=[0, 1, 
2, 3] + )(self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3), self.dtype(1e-3)) + self.assertAlmostEqual(value, slow_value, delta=8) + slow_d_amp, slow_d_length_scale, slow_d_noise, slow_d_jitter = slow_gradient + self.assertAlmostEqual(d_amp, slow_d_amp, delta=0.01) + self.assertAlmostEqual(d_length_scale, slow_d_length_scale, delta=0.001) + # TODO(thomaswc): Investigate why the noise gradient is so noisy. + self.assertAlmostEqual(d_noise, slow_d_noise, delta=0.5) + # TODO(thomaswc): Investigate why slow_d_jitter is zero. + self.assertAlmostEqual(d_jitter, slow_d_jitter, delta=1500) + + def test_gaussian_process_log_prob_gradient_of_index_points(self): + samples = jnp.array([ + [-0.7, -0.1, -0.2], + [-0.9, -0.4, -0.5], + [-0.3, -0.2, -0.8], + [-0.3, -0.1, -0.1], + [-0.5, -0.3, -0.6], + ], dtype=self.dtype) + + def fast_log_prob(pt1, pt2, pt3): + index_points = jnp.array([[pt1], [pt2], [pt3]]) + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.1), self.dtype(0.9)) + gp = fast_gp.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + lp = gp.log_prob(samples, key=jax.random.PRNGKey(1)) + return jnp.sum(lp) + + def slow_log_prob(pt1, pt2, pt3): + index_points = jnp.array([[pt1], [pt2], [pt3]]) + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.1), self.dtype(0.9)) + gp = tfd.GaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(3e-3), + jitter=self.dtype(1e-4)) + lp = gp.log_prob(samples) + return jnp.sum(lp) + + direct_slow_value = slow_log_prob( + self.dtype(-0.5), self.dtype(0.), self.dtype(0.5)) + direct_fast_value = fast_log_prob( + self.dtype(-0.5), self.dtype(0.), self.dtype(0.5)) + self.assertAlmostEqual(direct_slow_value, direct_fast_value, delta=1e-5) + + slow_value, slow_gradient = jax.value_and_grad( + slow_log_prob, argnums=[0, 1, 2] + )(self.dtype(-0.5), self.dtype(0.), self.dtype(0.5)) + fast_value, fast_gradient = jax.value_and_grad( + fast_log_prob, argnums=[0, 1, 2] + )(self.dtype(-0.5), self.dtype(0.), self.dtype(0.5)) + self.assertAlmostEqual(fast_value, slow_value, places=4) + np.testing.assert_allclose(fast_gradient, slow_gradient, rtol=1e-4) + + def test_yt_inv_y(self): + m = jnp.identity(100).astype(self.dtype) + np.testing.assert_allclose( + fast_gp.yt_inv_y( + m, + preconditioners.IdentityPreconditioner( + m + ).full_preconditioner(), + self.points, + max_iters=20, + ), + 30.0, + rtol=1e2, + ) + + def test_yt_inv_y_hard(self): + m = jnp.array([ + [ + 1.001, + 0.88311934, + 0.9894911, + 0.9695768, + 0.9987461, + 0.98577714, + 0.97863793, + 0.9880289, + 0.7110599, + 0.7718459, + ], + [ + 0.88311934, + 1.001, + 0.9395206, + 0.7564426, + 0.86025584, + 0.94721663, + 0.7791884, + 0.8075757, + 0.9478641, + 0.9758552, + ], + [ + 0.9894911, + 0.9395206, + 1.001, + 0.92534095, + 0.98108065, + 0.9997143, + 0.93953925, + 0.95583755, + 0.79332554, + 0.84795874, + ], + [ + 0.9695768, + 0.7564426, + 0.92534095, + 1.001, + 0.98049456, + 0.91640615, + 0.9991695, + 0.99564964, + 0.5614807, + 0.6257758, + ], + [ + 0.9987461, + 0.86025584, + 0.98108065, + 0.98049456, + 1.001, + 0.97622854, + 0.98763895, + 0.99449164, + 0.6813891, + 0.74358207, + ], + [ + 0.98577714, + 0.94721663, + 0.9997143, + 0.91640615, + 0.97622854, + 1.001, + 0.9313745, + 0.9487237, + 0.80610526, + 0.859435, + ], + [ + 0.97863793, + 0.7791884, + 0.93953925, + 0.9991695, + 0.98763895, + 0.9313745, + 1.001, + 0.99861676, + 0.5861309, + 0.65042824, + ], + [ + 0.9880289, + 0.8075757, + 0.95583755, + 
0.99564964, + 0.99449164, + 0.9487237, + 0.99861676, + 1.001, + 0.61803514, + 0.68201244, + ], + [ + 0.7110599, + 0.9478641, + 0.79332554, + 0.5614807, + 0.6813891, + 0.80610526, + 0.5861309, + 0.61803514, + 1.001, + 0.9943819, + ], + [ + 0.7718459, + 0.9758552, + 0.84795874, + 0.6257758, + 0.74358207, + 0.859435, + 0.65042824, + 0.68201244, + 0.9943819, + 1.001, + ], + ]).astype(self.dtype) + ys = jnp.array([ + [ + -0.0980842, + -0.27192444, + -0.22313793, + -0.07691352, + -0.1314459, + -0.2322599, + -0.1493263, + -0.11629149, + -0.34304294, + -0.24659212, + ], + [ + -0.12322001, + -0.23061615, + -0.13245171, + -0.03604657, + -0.18559735, + -0.2970187, + -0.11895001, + -0.03382884, + -0.28200114, + -0.25570437, + ], + [ + -0.18551889, + -0.13777351, + -0.08382752, + -0.17578323, + -0.26691607, + -0.06417686, + -0.22161345, + -0.18164475, + -0.17793402, + -0.22874065, + ], + [ + 0.29383075, + 0.34788758, + 0.31571257, + 0.2702031, + 0.31359673, + 0.32859725, + 0.28001747, + 0.36051235, + 0.5047121, + 0.455843, + ], + [ + -0.47330144, + -0.469457, + -0.42139763, + -0.3552108, + -0.47754064, + -0.47146142, + -0.5066414, + -0.4503611, + -0.5367922, + -0.5307923, + ], + ]).astype(self.dtype) + preconditioner = preconditioners.PartialCholeskySplitPreconditioner(m) + truth = jnp.einsum("ij,ij->j", ys.T, jnp.linalg.solve(m, ys.T)) + quadform = fast_gp.yt_inv_y(m, preconditioner.full_preconditioner(), ys.T) + np.testing.assert_allclose(truth, quadform, rtol=2e-4) + + truth2 = tfp.math.hpsd_quadratic_form_solvevec(m, ys) + np.testing.assert_allclose(truth2, quadform, rtol=2e-4) + + def test_yt_inv_y_derivative(self): + def quadratic(scale): + m = jnp.identity(5).astype(self.dtype) * scale + return fast_gp.yt_inv_y( + m, + preconditioners.IdentityPreconditioner(m).full_preconditioner(), + jnp.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=self.dtype)[..., jnp.newaxis], + )[0] + + d = jax.grad(quadratic) + # quadratic(s) = 55/s, quadratic'(s) = -55 / s^2 + self.assertAlmostEqual(d(self.dtype(1.0)), -55.0) + self.assertAlmostEqual(d(self.dtype(2.0)), -55.0/4.0) + + def test_yt_inv_y_derivative_with_diagonal_split_preconditioner(self): + def quadratic(scale): + m = jnp.identity(5).astype(self.dtype) * scale + return fast_gp.yt_inv_y( + m, + preconditioners.DiagonalSplitPreconditioner(m).full_preconditioner(), + jnp.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=self.dtype)[..., jnp.newaxis], + )[0] + + d = jax.grad(quadratic) + # quadratic(s) = 55/s, quadratic'(s) = -55 / s^2 + self.assertAlmostEqual(d(self.dtype(1.0)), -55.0) + self.assertAlmostEqual(d(self.dtype(2.0)), -55.0/4.0) + + def test_yt_inv_y_derivative_with_partial_cholesky_preconditioner(self): + def quadratic(scale): + m = jnp.identity(5).astype(self.dtype) * scale + return fast_gp.yt_inv_y( + m, + preconditioners.PartialCholeskyPreconditioner( + m).full_preconditioner(), + jnp.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=self.dtype)[..., jnp.newaxis], + )[0] + + d = jax.grad(quadratic) + # quadratic(s) = 55/s, quadratic'(s) = -55 / s^2 + self.assertAlmostEqual(d(self.dtype(1.0)), -55.0) + self.assertAlmostEqual(d(self.dtype(2.0)), -55.0/4.0, delta=5e-4) + + def test_yt_inv_y_derivative_with_rank_one_preconditioner(self): + def quadratic(scale): + m = jnp.identity(5).astype(self.dtype) * scale + return fast_gp.yt_inv_y( + m, + preconditioners.RankOnePreconditioner( + m, key=jax.random.PRNGKey(5)).full_preconditioner(), + jnp.array( + [1.0, 2.0, 3.0, 4.0, 5.0], dtype=self.dtype)[..., jnp.newaxis], + )[0] + + d = jax.grad(quadratic) + # quadratic(s) = 
55/s, quadratic'(s) = -55 / s^2 + self.assertAlmostEqual(d(self.dtype(1.0)), -55.0, delta=6e-2) + self.assertAlmostEqual(d(self.dtype(2.0)), -55.0/4.0, delta=2e-2) + + def test_yt_inv_y_derivative_hard(self): + y = jnp.array( + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]], dtype=self.dtype) + b = ( + jnp.diag(jnp.full(10, 2.0)) + + jnp.diag(jnp.full(9, 1.0), 1) + + jnp.diag(jnp.full(9, 1.0), -1) + ).astype(self.dtype) + + def quad_form(jitter): + m = b + jitter * jnp.identity(10).astype(self.dtype) + pc = preconditioners.PartialCholeskySplitPreconditioner(m) + return fast_gp.yt_inv_y(m, pc.full_preconditioner(), y.T)[0] + + d = jax.grad(quad_form) + + def quad_form2(jitter): + m = b + jitter * jnp.identity(10).astype(self.dtype) + return tfp.math.hpsd_quadratic_form_solvevec(m, y)[0] + + d2 = jax.grad(quad_form2) + + self.assertAlmostEqual(d(0.1), d2(0.1), delta=1e-4) + self.assertAlmostEqual(d(1.0), d2(1.0), delta=1e-4) + + +class FastGpTestFloat32(_FastGpTest): + dtype = np.float32 + + +class FastGpTestFloat64(_FastGpTest): + dtype = np.float64 + + +del _FastGpTest + + +if __name__ == "__main__": + config.update("jax_enable_x64", True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gprm.py b/tensorflow_probability/python/experimental/fastgp/fast_gprm.py new file mode 100644 index 0000000000..386089d30e --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_gprm.py @@ -0,0 +1,302 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Fast GPRM.""" + +import jax +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import fast_gp +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.python.experimental.fastgp import schur_complement +from tensorflow_probability.substrates import jax as tfp + +parameter_properties = tfp.internal.parameter_properties +tfd = tfp.distributions +jtf = tfp.tf2jax + + +__all__ = [ + 'GaussianProcessRegressionModel', +] + + +class GaussianProcessRegressionModel(fast_gp.GaussianProcess): + """Fast, JAX-only implementation of a GP distribution class. + + See tfd.distributions.GaussianProcessRegressionModel for a description and + parameter documentation. Note: We assume that the observation index points + and observations are fixed, and so precompute quantities associated with + them. + """ + + def __init__( + self, + kernel, + key: jax.random.PRNGKey, + index_points=None, + observation_index_points=None, + observations=None, + observation_noise_variance=0.0, + predictive_noise_variance=None, + mean_fn=None, + jitter=1e-6, + config=fast_gp.GaussianProcessConfig(), + ): + """Instantiate a fast GaussianProcessRegressionModel instance. 
+ + Args: + kernel: `PositiveSemidefiniteKernel`-like instance representing the GP's + covariance function. + key: `jax.random.PRNGKey` to use when computing the preconditioner. + index_points: (nested) `Tensor` representing finite collection, or batch + of collections, of points in the index set over which the GP is defined. + Shape (of each nested component) has the form `[b1, ..., bB, e, f1, ..., + fF]` where `F` is the number of feature dimensions and must equal + `kernel.feature_ndims` (or its corresponding nested component) and `e` + is the number (size) of index points in each batch. Ultimately this + distribution corresponds to an `e`-dimensional multivariate normal. The + batch shape must be broadcastable with `kernel.batch_shape` and any + batch dims yielded by `mean_fn`. + observation_index_points: (nested) `Tensor` representing finite + collection, or batch of collections, of points in the index set for + which some data has been observed. Shape (of each nested component) has + the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of + feature dimensions and must equal `kernel.feature_ndims` (or its + corresponding nested component), and `e` is the number (size) of index + points in each batch. `[b1, ..., bB, e]` must be broadcastable with the + shape of `observations`, and `[b1, ..., bB]` must be broadcastable with + the shapes of all other batched parameters (`kernel.batch_shape`, + `index_points`, etc.). The default value is `None`, which corresponds to + the empty set of observations, and simply results in the prior + predictive model (a GP with noise of variance + `predictive_noise_variance`). + observations: `float` `Tensor` representing collection, or batch of + collections, of observations corresponding to + `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which + must be broadcastable with the batch and example shapes of + `observation_index_points`. The batch shape `[b1, ..., bB]` must be + broadcastable with the shapes of all other batched parameters + (`kernel.batch_shape`, `index_points`, etc.). The default value is + `None`, which corresponds to the empty set of observations, and simply + results in the prior predictive model (a GP with noise of variance + `predictive_noise_variance`). + observation_noise_variance: `float` `Tensor` representing the variance of + the noise in the Normal likelihood distribution of the model. May be + batched, in which case the batch shape must be broadcastable with the + shapes of all other batched parameters (`kernel.batch_shape`, + `index_points`, etc.). Default value: `0.` + predictive_noise_variance: `float` `Tensor` representing the variance in + the posterior predictive model. If `None`, we simply re-use + `observation_noise_variance` for the posterior predictive noise. If set + explicitly, however, we use this value. This allows us, for example, to + omit predictive noise variance (by setting this to zero) to obtain + noiseless posterior predictions of function values, conditioned on noisy + observations. + mean_fn: Python `callable` that acts on `index_points` to produce a + collection, or batch of collections, of mean values at `index_points`. + Takes a (nested) `Tensor` of shape `[b1, ..., bB, e, f1, ..., fF]` and + returns a `Tensor` whose shape is broadcastable with `[b1, ..., bB, e]`. + Default value: `None` implies the constant zero function. + jitter: `float` scalar `Tensor` that gets added to the diagonal of the + GP's covariance matrix to ensure it is positive definite.
+ config: `GaussianProcessConfig` to control speed and quality of GP + approximations. + + Raises: + ValueError: if either + - only one of `observations` and `observation_index_points` is given, or + - `mean_fn` is not `None` and not callable. + """ + # TODO(srvasude): Add support for masking observations. In addition, cache + # the observation matrix so that it isn't recomputed every iteration. + parameters = dict(locals()) + input_dtype = tfp.internal.dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + observation_index_points=observation_index_points, + ), + dtype_hint=tfp.internal.nest_util.broadcast_structure( + kernel.feature_ndims, np.float32)) + + # If the input dtype is non-nested float, we infer a single dtype for the + # input and the float parameters, which is also the dtype of the GP's + # samples, log_prob, etc. If the input dtype is nested (or not float), we + # do not use it to infer the GP's float dtype. + if (not jax.tree_util.treedef_is_leaf( + jax.tree_util.tree_structure(input_dtype)) and + tfp.internal.dtype_util.is_floating(input_dtype)): + dtype = tfp.internal.dtype_util.common_dtype( + dict( + kernel=kernel, + index_points=index_points, + observations=observations, + observation_index_points=observation_index_points, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + jitter=jitter, + ), + dtype_hint=np.float32, + ) + input_dtype = dtype + else: + dtype = tfp.internal.dtype_util.common_dtype( + dict( + observations=observations, + observation_noise_variance=observation_noise_variance, + predictive_noise_variance=predictive_noise_variance, + jitter=jitter, + ), dtype_hint=np.float32) + + if predictive_noise_variance is None: + predictive_noise_variance = observation_noise_variance + if (observation_index_points is None) != (observations is None): + raise ValueError( + '`observations` and `observation_index_points` must both be given ' + 'or None. Got {} and {}, respectively.'.format( + observations, observation_index_points)) + # Default to a constant zero function, borrowing the dtype from + # index_points to ensure consistency. 
+ mean_fn = tfd.internal.stochastic_process_util.maybe_create_mean_fn( + mean_fn, dtype) + + self._observation_index_points = observation_index_points + self._observations = observations + self._observation_noise_variance = observation_noise_variance + self._predictive_noise_variance = predictive_noise_variance + self._jitter = jitter + + covariance = kernel.matrix( + observation_index_points, observation_index_points) + + is_scaling_preconditioner = config.preconditioner.endswith('scaling') + + def get_preconditioner(cov): + scaling = None + if is_scaling_preconditioner: + scaling = self._observation_noise_variance + self._jitter + return preconditioners.get_preconditioner( + config.preconditioner, + cov, + key=key, + rank=config.preconditioner_rank, + num_iters=config.preconditioner_num_iters, + scaling=scaling) + + if is_scaling_preconditioner: + schur_preconditioner = get_preconditioner(covariance) + + updated_diagonal = jnp.diag(covariance) + ( + self._observation_noise_variance + self._jitter) + + covariance = (covariance * ( + 1 - jnp.eye( + updated_diagonal.shape[-1], dtype=updated_diagonal.dtype) + ) + jnp.diag(updated_diagonal)) + + if not is_scaling_preconditioner: + schur_preconditioner = get_preconditioner(covariance) + + conditional_kernel = schur_complement.SchurComplement( + base_kernel=kernel, + preconditioner_fn=schur_preconditioner.full_preconditioner().solve, + fixed_inputs=observation_index_points, + diag_shift=self._observation_noise_variance + self._jitter) + + def conditional_mean_fn(x): + """Conditional mean.""" + k_x_obs = kernel.matrix(x, observation_index_points) + diff = observations - mean_fn(observation_index_points) + k_obs_inv_diff, _ = mbcg.modified_batched_conjugate_gradients( + lambda x: covariance @ x, + diff[..., jnp.newaxis], + preconditioner_fn=schur_preconditioner.full_preconditioner().solve, + max_iters=config.cg_iters, + ) + + return mean_fn(x) + jnp.squeeze(k_x_obs @ k_obs_inv_diff, axis=-1) + + # Special logic for mean_fn only; SchurComplement already handles the + # case of empty observations (ie, falls back to base_kernel). + if not tfd.internal.stochastic_process_util.is_empty_observation_data( + feature_ndims=kernel.feature_ndims, + observation_index_points=observation_index_points, + observations=observations): + tfd.internal.stochastic_process_util.validate_observation_data( + kernel=kernel, + observation_index_points=observation_index_points, + observations=observations) + + super(GaussianProcessRegressionModel, self).__init__( + index_points=index_points, + jitter=jitter, + kernel=conditional_kernel, + mean_fn=conditional_mean_fn, + # What the GP super class calls "observation noise variance" we call + # here the "predictive noise variance". We use the observation noise + # variance for the fit/solve process above, and predictive for + # downstream computations like sampling. 
observation_noise_variance=predictive_noise_variance, + config=config, + ) + self._parameters = parameters + + @property + def observation_index_points(self): + return self._observation_index_points + + @property + def observations(self): + return self._observations + + @property + def predictive_noise_variance(self): + return self._predictive_noise_variance + + @classmethod + def _parameter_properties(cls, dtype, num_classes=None): + def _event_ndims_fn(self): + return jax.tree_util.tree_map( + lambda nd: nd + 1, self.kernel.feature_ndims) + return dict( + index_points=tfp.util.ParameterProperties( + event_ndims=_event_ndims_fn, + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + ), + observations=tfp.util.ParameterProperties( + event_ndims=1, + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED), + observation_index_points=tfp.util.ParameterProperties( + event_ndims=_event_ndims_fn, + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + ), + observations_is_missing=tfp.util.ParameterProperties( + event_ndims=1, + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + ), + kernel=tfp.util.BatchedComponentProperties(), + observation_noise_variance=tfp.util.ParameterProperties( + event_ndims=0, + shape_fn=lambda sample_shape: sample_shape[:-1], + default_constraining_bijector_fn=( + lambda: tfp.bijectors.Softplus( # pylint:disable=g-long-lambda + low=tfp.internal.dtype_util.eps(dtype)))), + predictive_noise_variance=tfp.util.ParameterProperties( + event_ndims=0, + shape_fn=lambda sample_shape: sample_shape[:-1], + default_constraining_bijector_fn=( + lambda: tfp.bijectors.Softplus( # pylint:disable=g-long-lambda + low=tfp.internal.dtype_util.eps(dtype))))) diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gprm_test.py b/tensorflow_probability/python/experimental/fastgp/fast_gprm_test.py new file mode 100644 index 0000000000..0cd217f8f6 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_gprm_test.py @@ -0,0 +1,87 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""Tests for fast_gprm.py.""" + +import jax +import numpy as np +from tensorflow_probability.python.experimental.fastgp import fast_gp +from tensorflow_probability.python.experimental.fastgp import fast_gprm +from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.substrates.jax.internal import test_util + +tfd = tfp.distributions + + +class _FastGprmTest(test_util.TestCase): + + def testShapes(self): + # 5x5 grid of index points in R^2 and flatten to 25x2 + index_points = np.linspace(-4., 4., 5) + index_points = np.stack(np.meshgrid(index_points, index_points), axis=-1) + index_points = np.reshape(index_points, [-1, 2]).astype(self.dtype) + # ==> shape = [25, 2] + + seeds = jax.random.split(test_util.test_seed(sampler_type="stateless"), 5) + + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.), self.dtype(1.)) + + observation_noise_variance = self.dtype(1e-6) + observation_index_points = jax.random.uniform( + seeds[0], shape=(7, 2)).astype(self.dtype) + observations = jax.random.uniform(seeds[1], shape=(7,)).astype(self.dtype) + jitter = self.dtype(1e-6) + + dist = fast_gprm.GaussianProcessRegressionModel( + kernel, + seeds[2], + index_points, + observation_index_points, + observations, + observation_noise_variance, + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling", + preconditioner_num_iters=25), + jitter=jitter, + ) + + true_dist = tfd.GaussianProcessRegressionModel( + kernel, + index_points, + observation_index_points, + observations, + observation_noise_variance, + jitter=jitter) + + self.assertEqual(true_dist.event_shape, dist.event_shape) + np.testing.assert_allclose( + dist.mean(), true_dist.mean(), rtol=3e-1, atol=1e-4) + np.testing.assert_allclose( + dist.variance(), true_dist.variance(), rtol=3e-1, atol=1e-3) + + +class FastGprmTestFloat32(_FastGprmTest): + dtype = np.float32 + + +class FastGprmTestFloat64(_FastGprmTest): + dtype = np.float64 + + +del _FastGprmTest + + +if __name__ == "__main__": + test_util.main() diff --git a/tensorflow_probability/python/experimental/fastgp/fast_log_det.py b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py new file mode 100644 index 0000000000..92598d513f --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py @@ -0,0 +1,734 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions for quickly computing approximate log det of a big PSD matrix. + +It's recommended to use `fast_log_det` in `float64` mode only. 
+""" + +import enum +import functools + +import jax +import jax.numpy as jnp +from jaxtyping import Float +import numpy as np +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.python.experimental.fastgp import partial_lanczos +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.substrates import jax as tfp + +jtf = tfp.tf2jax +Array = jnp.ndarray + +# pylint: disable=invalid-name + + +class ProbeVectorType(enum.IntEnum): + RADEMACHER = 0 + NORMAL = 1 + NORMAL_ORTHOGONAL = 2 + NORMAL_QMC = 3 + + +@jax.named_call +def make_probe_vectors( + n: int, + num_probe_vectors: int, + key: jax.Array, + probe_vector_type: ProbeVectorType, + dtype: jnp.dtype, +) -> Array: + """Return num_probe_vectors n-dim random vectors with mean zero.""" + if probe_vector_type == ProbeVectorType.RADEMACHER: + return jax.random.choice( + key, jnp.array([-1.0, 1.0], dtype=dtype), shape=(n, num_probe_vectors) + ) + + if probe_vector_type == ProbeVectorType.NORMAL: + return jax.random.normal(key, shape=(n, num_probe_vectors), dtype=dtype) + + if probe_vector_type == ProbeVectorType.NORMAL_ORTHOGONAL: + if num_probe_vectors > n: + print(f'Warning, make_probe_vectors(normal_orthogonal) called with ' + f'{num_probe_vectors=} > {n=} Falling back on normal.') + return jax.random.normal(key, shape=(n, num_probe_vectors), dtype=dtype) + # Sample a random orthogonal matrix. + key1, key2 = jax.random.split(key) + samples = jax.random.normal(key1, shape=(n, num_probe_vectors), dtype=dtype) + q, _ = jnp.linalg.qr(samples, mode='reduced') + # Rescale up by a chi random variable. + norm = jnp.sqrt( + jax.random.chisquare( + key2, df=n, shape=(num_probe_vectors,), dtype=dtype)) + return q * norm + + if probe_vector_type == ProbeVectorType.NORMAL_QMC: + uniforms = tfp.mcmc.sample_halton_sequence( + dim=n, + num_results=num_probe_vectors, + dtype=dtype, + seed=key) + return jnp.transpose(jax.scipy.special.ndtri(uniforms)) + + raise ValueError( + f'Unknown probe vector type {probe_vector_type}.' + ' Try NORMAL, NORMAL_QMC or RADEMACHER.' + ) + + +@jax.named_call +def log_det_jvp(primals, tangents, f, num_iters): + """Jacobian-Vector product for log_det. + + This can be used to provide a jax.custom_jvp for a function with the + signature (matrix, preconditioner, probe_vectors, key, + *optional_args, num_iters) -> Float. + + [num_iters needs to be last in the signature because it needs to be + a non-diff argument of the custom.jvp, which means it isn't passed in + the primals and we need to put it back in here.] + + Args: + primals: The arguments f was called with. + tangents: The differentials of the primals. + f: The log det function log_det_jvp is the custom.jvp of. + num_iters: The number of iterations to run conjugate gradients for. + + Returns: + The pair (f(primals), df). + """ + primal_out = f(*primals, num_iters=num_iters) + + M = primals[0] + M_dot = tangents[0] + preconditioner = primals[1] + probe_vectors = primals[2] + num_probe_vectors = probe_vectors.shape[-1] + + # d(log det M) = tr(M^(-1) dM) + # Traditionally, this trace is approximated using Hutchinson's trick. + # However, we find that that estimate has a large variance when the + # operator A = M^(-1) dM is badly conditioned. So we use our preconditioner + # P to make B = - P^(-1) dM, and get tr(A) = tr(A+B) - tr(B), with the + # first term coming from Hutchinson's trick with hopefully low variance and + # the second tr(B) term being computed directly. 
+ trace_B = -preconditioner.trace_of_inverse_product(M_dot) + + # tr(A+B) = tr( M^(-1) dM - P^(-1) dM). + # = E[ v^t (M^(-1) dM - P^(-1) dM) v ] + # = E[ (v^t M^(-1) - v^t P^(-1)) (dM v) ] + + # left_factor1 = v^t M^(-1) + left_factor1, _ = mbcg.modified_batched_conjugate_gradients( + lambda x: M @ x, probe_vectors, + preconditioner.full_preconditioner().solve, + max_iters=num_iters + ) + # left_factor2 = v^t P^(-1) + left_factor2 = preconditioner.full_preconditioner().solvevec(probe_vectors.T) + left_factor = left_factor1 - left_factor2.T + # right_factor = dM probe_vectors + right_factor = M_dot @ probe_vectors + + unnormalized_trace_of_A_plus_B = jnp.einsum( + 'ij,ij->', left_factor, right_factor + ) + trace_of_A_plus_B = unnormalized_trace_of_A_plus_B / num_probe_vectors + + tangent_out = trace_of_A_plus_B - trace_B + + return primal_out, tangent_out + + +@jax.named_call +def _log_det_rational_approx_with_hutchinson( + shifts, + coefficients, + bias, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a rational function. + + We calculate log det M as the trace of log M, and we approximate the + trace of log M using Hutchinson's trick. We then approximate + (log M) @ probe_vector using the partial fraction decomposition + log M ~ bias + sum_i coefficients[i] / (M - shifts[i]) + and finally we get the (1 / (M - shifts[i]) ) @ probe_vector parts + using a multishift solver. + + Args: + shifts: An array of length r. When approximating log z with p(x)/q(x), + shifts will contain the roots of q(x). + coefficients: An array of length r. + bias: A scalar Float. + preconditioner: A preconditioner of M. It is used both in speeding the + convergence of the multishift solve, and in reducing the variance of the + approximation used to compute the derivative of this function. + probe_vectors: An array of shape (n, num_probe_vectors). Each probe vector + should be composed of i.i.d. random variables with mean 0 and variance 1. + key: The RNG key. + num_iters: The number of iterations to run the partial Lanczos algorithm + for. + + Returns: + An approximation to the log det of M. 
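
  Scalar intuition (illustrative): with the order-2 constants `R2_SHIFTS` and
  `R2_COEFFICIENTS` defined below and bias = 4.0, the scalar function

    r(z) = 4.0 + sum_i R2_COEFFICIENTS[i] / (z - R2_SHIFTS[i])

  satisfies r(1.0) = 0.0 = log(1.0) exactly (this is what
  `test_rational_parameters` in fast_log_det_test.py checks), and it
  approximates log(z) for z near 1, which is the regime a good
  preconditioner is intended to place the spectrum of M in.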
+ """ + num_probe_vectors = probe_vectors.shape[-1] + + solutions = partial_lanczos.psd_solve_multishift( + preconditioner.preconditioned_operator().matmul, + probe_vectors, + shifts, + key, + num_iters, + ) + # solutions will be (num_shifts, num_probes, n) + weighted_solutions = jnp.einsum('i,ijk->kj', coefficients, solutions) + # logM_pv is our approximation to (log M) @ probe_vectors + logM_pv = bias * probe_vectors + weighted_solutions + + return ( + preconditioner.log_det() + + jnp.einsum('ij,ij->', probe_vectors, logM_pv) / num_probe_vectors + ) + + +R1_SHIFTS = np.array([-1.0], dtype=np.float64) +R1_COEFFICIENTS = np.array([-4.0], dtype=np.float64) + +R2_SHIFTS = np.array( + [-5.828427124746191, -0.1715728752538099], dtype=np.float64 +) +R2_COEFFICIENTS = np.array( + [-23.313708498984763, -0.6862915010152396], dtype=np.float64 +) + +R3_SHIFTS = np.array( + [-13.92820323027551, -1.0, -0.0717967697244908], dtype=np.float64 +) +R3_COEFFICIENTS = np.array( + [-49.52250037431294, -2.2222222222222214, -0.2552774034648563], + dtype=np.float64, +) + +R4_SHIFTS = np.array( + [ + -25.27414236908818, + -2.2398288088435496, + -0.4464626921716892, + -0.03956612989657948, + ], + dtype=np.float64, +) +R4_COEFFICIENTS = np.array( + [ + -91.22640292804368, + -3.861145971009117, + -0.7696381162159669, + -0.1428129847311194, + ], + dtype=np.float64, +) + +R5_SHIFTS = np.array( + [ + -3.9863458189061411e01, + -3.8518399963191827, + -1.0, + -2.5961618368249978e-01, + -2.5085630936916615e-02, + ], + dtype=np.float64, +) +R5_COEFFICIENTS = np.array( + [ + -1.4008241129102026e02, + -6.1858406006156228e00, + -1.2266666666666659e00, + -4.1692913805732562e-01, + -8.8152303639431204e-02, + ], + dtype=np.float64, +) + +R6_SHIFTS = np.array( + [ + -5.7695480540981052e01, + -5.8284271247461907e00, + -1.6983963724170996e00, + -5.8879070648086351e-01, + -1.7157287525380993e-01, + -1.7332380120999309e-02, + ], + dtype=np.float64, +) +R6_COEFFICIENTS = np.array( + [ + -2.0440306874472464e02, + -8.8074009885053552e00, + -1.8333009080451452e00, + -6.3555866838298825e-01, + -2.5926567816131274e-01, + -6.1405012180561776e-02, + ], + dtype=np.float64, +) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r1( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 1st order rational function.""" + return _log_det_rational_approx_with_hutchinson( + jnp.asarray(R1_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R1_COEFFICIENTS, dtype=probe_vectors.dtype), + 2.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r2( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 2nd order rational function.""" + return _log_det_rational_approx_with_hutchinson( + jnp.asarray(R2_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R2_COEFFICIENTS, dtype=probe_vectors.dtype), + 4.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r3( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 4th order rational function.""" + return 
_log_det_rational_approx_with_hutchinson( + jnp.asarray(R3_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R3_COEFFICIENTS, dtype=probe_vectors.dtype), + 14.0 / 3.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r4( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 4th order rational function.""" + return _log_det_rational_approx_with_hutchinson( + jnp.asarray(R4_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R4_COEFFICIENTS, dtype=probe_vectors.dtype), + 16.0 / 3.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r5( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 4th order rational function.""" + return _log_det_rational_approx_with_hutchinson( + jnp.asarray(R5_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R5_COEFFICIENTS, dtype=probe_vectors.dtype), + 86.0 / 15.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) +def _r6( + unused_M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + key: jax.Array, + num_iters: int, +) -> Float: + """Approximate log det using a 4th order rational function.""" + return _log_det_rational_approx_with_hutchinson( + jnp.asarray(R6_SHIFTS, dtype=probe_vectors.dtype), + jnp.asarray(R6_COEFFICIENTS, dtype=probe_vectors.dtype), + 92.0 / 15.0, + preconditioner, + probe_vectors, + key, + num_iters, + ) + + +@_r1.defjvp +def _r1_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r1.""" + return log_det_jvp(primals, tangents, _r1, num_iters) + + +@_r2.defjvp +def _r2_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r2.""" + return log_det_jvp(primals, tangents, _r2, num_iters) + + +@_r3.defjvp +def _r3_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r3.""" + return log_det_jvp(primals, tangents, _r3, num_iters) + + +@_r4.defjvp +def _r4_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r4.""" + return log_det_jvp(primals, tangents, _r4, num_iters) + + +@_r5.defjvp +def _r5_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r5.""" + return log_det_jvp(primals, tangents, _r5, num_iters) + + +@_r6.defjvp +def _r6_jvp(num_iters, primals, tangents): + """Jacobian-Vector product for r6.""" + return log_det_jvp(primals, tangents, _r6, num_iters) + + +@jax.named_call +def r1( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 2nd order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r1(M, preconditioner, probe_vectors, key2, num_iters) + + +@jax.named_call +def r2( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = 
ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 2nd order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r2(M, preconditioner, probe_vectors, key2, num_iters) + + +@jax.named_call +def r3( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 3rd order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r3(M, preconditioner, probe_vectors, key2, num_iters) + + +@jax.named_call +def r4( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 4th order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r4(M, preconditioner, probe_vectors, key2, num_iters) + + +@jax.named_call +def r5( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 5th order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r5(M, preconditioner, probe_vectors, key2, num_iters) + + +@jax.named_call +def r6( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 20, + **unused_kwargs, +) -> Float: + """Approximate log det using a 6th order rational function.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key1, probe_vector_type, dtype=M.dtype + ) + + return _r6(M, preconditioner, probe_vectors, key2, num_iters) + + +def log00(diag: Array, off_diag: Array) -> Array: + """Return the (0, 0)-th entry of the log of the tridiagonal matrix.""" + n = diag.shape[-1] + if n == 1: + return jnp.log(diag[0]) + m = jnp.diag(diag) + jnp.diag(off_diag, -1) + jnp.diag(off_diag, 1) + # Use jax.numpy.linalg.eigh instead of scipy.linalg.eigh_tridiagonal + # because the later isn't yet jax-able. We would use + # jax.scipy.linalg.eigh_tridiagonal, but that doesn't yet return + # eigenvectors. TODO(thomaswc): Switch when it does. 
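  # Illustrative check of the identity computed below (standalone sketch):
  # for the 2x2 tridiagonal T = [[2., 1.], [1., 3.]] with T = V diag(w) V^T,
  #   (log T)[0, 0] = sum_i V[0, i]^2 * log(w[i]),
  # e.g.
  #   w, V = jnp.linalg.eigh(jnp.array([[2., 1.], [1., 3.]]))
  #   jnp.einsum('i,i,i->', V[0, :], jnp.log(w), V[0, :])  # ~0.5895146
  # which is the value asserted by `test_log00` in fast_log_det_test.py.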
+ evalues, evectors = jax.numpy.linalg.eigh(m) + log_evalues = jnp.log(evalues) + first_components = evectors[0, :] + return jnp.einsum('i,i,i->', first_components, log_evalues, first_components) + + +@jax.named_call +def batch_log00(ts: mbcg.SymmetricTridiagonalMatrix) -> Array: + """Return the (0, 0)-th entries of the log of the tridiagonal matrices.""" + return jax.vmap(log00)(ts.diag, ts.off_diag) + + +@functools.partial(jax.custom_jvp, nondiff_argnums=(4, 5)) +@functools.partial( + jax.jit, static_argnames=['probe_vectors_are_rademacher', 'num_iters'] +) +def _stochastic_lanczos_quadrature_log_det( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + probe_vectors: Array, + unused_key, + probe_vectors_are_rademacher: bool, + num_iters: int, +) -> Float: + """Fast log det using the alg. from https://arxiv.org/pdf/1809.11165.pdf .""" + n = M.shape[-1] + + _, tridiagonals = mbcg.modified_batched_conjugate_gradients( + lambda x: M @ x, + probe_vectors, + preconditioner.full_preconditioner().solve, + num_iters + # TODO(thomaswc): Pass tolerance=-1 here to make sure mbcg never + # stops early. Currently that is broken. + ) + + # The modified_batched_conjugate_gradients applies a rotation so that the + # tridiagonal matrices are written in a basis where e_0 is parallel to the + # probe vector. Therefore, + # probe_vector^T A probe_vector = (a e_0)^T A (a e_0) + # = a^2 A_00 + # for a = len(probe_vector). + sum_squared_probe_vector_lengths = None + if probe_vectors_are_rademacher: + sum_squared_probe_vector_lengths = n + else: + sum_squared_probe_vector_lengths = jnp.einsum( + 'ij,ij->j', probe_vectors, probe_vectors + ) + trace_log_estimates = sum_squared_probe_vector_lengths * batch_log00( + tridiagonals + ) + + return preconditioner.log_det() + jnp.average(trace_log_estimates) + + +@jax.named_call +def stochastic_lanczos_quadrature_log_det( + M: jtf.linalg.LinearOperator, + preconditioner: preconditioners.Preconditioner, + key: jax.Array, + num_probe_vectors: int = 25, + probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER, + num_iters: int = 25, + **unused_kwargs, +) -> Float: + """Fast log det using the alg. 
from https://arxiv.org/pdf/1809.11165.pdf .""" + n = M.shape[-1] + num_iters = min(n, num_iters) + + probe_vectors = make_probe_vectors( + n, num_probe_vectors, key, probe_vector_type, dtype=M.dtype + ) + + return _stochastic_lanczos_quadrature_log_det( + M, + preconditioner, + probe_vectors, + None, + probe_vector_type == ProbeVectorType.RADEMACHER, + num_iters, + ) + + +@_stochastic_lanczos_quadrature_log_det.defjvp +def _stochastic_lanczos_quadrature_log_det_jvp( + probe_vectors_are_rademacher, num_iters, primals, tangents +): + """Jacobian-Vector product for @_stochastic_lanczos_quadrature_log_det.""" + def slq_f(M, preconditioner, probe_vectors, unused_key, num_iters): + return _stochastic_lanczos_quadrature_log_det( + M, preconditioner, probe_vectors, unused_key, + probe_vectors_are_rademacher, num_iters) + + return log_det_jvp(primals, tangents, slq_f, num_iters) + + +LOG_DET_REGISTRY = { + 'r1': r1, + 'r2': r2, + 'r3': r3, + 'r4': r4, + 'r5': r5, + 'r6': r6, + 'slq': stochastic_lanczos_quadrature_log_det, +} + + +@jax.named_call +def get_log_det_algorithm(alg_name: str): + try: + return LOG_DET_REGISTRY[alg_name] + except KeyError as key_error: + raise ValueError( + 'Unknown algorithm name {}, known log det algorithms are {}'.format( + alg_name, LOG_DET_REGISTRY.keys() + ) + ) from key_error + + +# The below log det algorithms are not yet useful enough to make it into the +# LOG_DET_REGISTRY. + + +def log_det_taylor_series_with_hutchinson( + M: jtf.linalg.LinearOperator, + num_probe_vectors: int, + key: jax.Array, + num_taylor_series_iterations: int = 10, +) -> Float: + """Return an approximation of log det M.""" + # TODO(thomaswc): Consider having this support a batch of LinearOperators. + n = M.shape[0] + A = M - jnp.identity(n) + probe_vectors = jax.random.choice( + key, jnp.array([-1.0, 1.0], dtype=M.dtype), shape=(n, num_probe_vectors)) + estimate = 0.0 + Apv = probe_vectors + sign = 1 + for i in range(1, num_taylor_series_iterations + 1): + Apv = A @ Apv + trace_estimate = 0 + if i == 1: + trace_estimate = jnp.trace(A) + elif i == 2: + # tr A^2 = sum_i (A^2)_ii = sum_i sum_j A_ij A_ji + # = sum_i sum_j A_ij A_ij since A is symmetric. + trace_estimate = jnp.einsum('ij,ij->', A, A) + else: + trace_estimate = ( + jnp.einsum('ij,ij->', Apv, probe_vectors) / num_probe_vectors + ) + estimate = estimate + sign * trace_estimate / i + sign = -sign + + return estimate diff --git a/tensorflow_probability/python/experimental/fastgp/fast_log_det_test.py b/tensorflow_probability/python/experimental/fastgp/fast_log_det_test.py new file mode 100644 index 0000000000..45fc4f2112 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_log_det_test.py @@ -0,0 +1,680 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for fast_log_det.py.""" + +import math + +from absl.testing import parameterized +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import fast_log_det +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.substrates import jax as tfp + +from absl.testing import absltest + +# pylint: disable=invalid-name + + +def rational_at_one(shifts, coefficients): + """Return sum_i coefficients[i] / (1.0 - shifts[i]).""" + n = shifts.shape[-1] + s = 0.0 + for i in range(n): + s += coefficients[i] / (1.0 - shifts[i]) + return s + + +class _FastLogDetTest(parameterized.TestCase): + def test_make_probe_vectors_rademacher(self): + pvs = fast_log_det.make_probe_vectors( + 10, + 5, + jax.random.PRNGKey(0), + fast_log_det.ProbeVectorType.RADEMACHER, + dtype=self.dtype) + for i in range(10): + for j in range(5): + self.assertIn(float(pvs[i, j]), {-1.0, 1.0}) + + @parameterized.parameters( + fast_log_det.ProbeVectorType.NORMAL, + fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL, + fast_log_det.ProbeVectorType.NORMAL_QMC) + def test_make_probe_vectors(self, probe_vector_type): + pvs = fast_log_det.make_probe_vectors( + 10, + 5, + jax.random.PRNGKey(0), + probe_vector_type, + dtype=self.dtype) + for i in range(10): + for j in range(5): + self.assertNotIn(float(pvs[i, j]), {-1.0, 1.0}) + + def test_rational_parameters(self): + self.assertAlmostEqual( + 0.0, + 4.0 + + rational_at_one(fast_log_det.R2_SHIFTS, fast_log_det.R2_COEFFICIENTS), + ) + self.assertAlmostEqual( + 0.0, + 14.0 / 3.0 + + rational_at_one(fast_log_det.R3_SHIFTS, fast_log_det.R3_COEFFICIENTS), + ) + self.assertAlmostEqual( + 0.0, + 16.0 / 3.0 + + rational_at_one(fast_log_det.R4_SHIFTS, fast_log_det.R4_COEFFICIENTS), + ) + self.assertAlmostEqual( + 0.0, + 86.0 / 15.0 + + rational_at_one(fast_log_det.R5_SHIFTS, fast_log_det.R5_COEFFICIENTS), + ) + self.assertAlmostEqual( + 0.0, + 92.0 / 15.0 + + rational_at_one(fast_log_det.R6_SHIFTS, fast_log_det.R6_COEFFICIENTS), + ) + + def test_r2_same_as_rational(self): + num_probe_vectors = 5 + M = jnp.array([[1.0, -0.5, 0.0], [-0.5, 1.0, -0.5], [0.0, -0.5, 1.0]], + dtype=self.dtype) + I = jnp.eye(3, dtype=self.dtype) + pvs = fast_log_det.make_probe_vectors( + 3, + num_probe_vectors, + jax.random.PRNGKey(1), + fast_log_det.ProbeVectorType.RADEMACHER, + dtype=self.dtype + ) + pc = preconditioners.IdentityPreconditioner(M) + _r2_answer = fast_log_det._r2(M, pc, pvs, jax.random.PRNGKey(2), 20) + # r2(z) = 4(z-1)(1+z)/(1 + 6z + z^2) + _rat_answer = ( + 4.0 + * (M - I) + @ (M + I) + @ jnp.linalg.inv(I + 6.0 * M + M @ M) + @ pvs + ) + _rat_answer = jnp.einsum('ij,ij->', pvs, _rat_answer) / float( + num_probe_vectors + ) + np.testing.assert_allclose(_r2_answer, _rat_answer, atol=0.1) + + @parameterized.parameters( + ('r1', 1), + ('r2', 1), + ('r3', 2), + ('r4', 2), + ('r5', 2), + ('r6', 2), + ('slq', 2), + ) + def test_log_det_algorithm_in_low_dim(self, log_det_alg, num_places): + lda = fast_log_det.get_log_det_algorithm(log_det_alg) + + # For 1x1 arrays, for Rademacher probe vectors, v^t A v = A = tr A, so + # only one probe vector is necessary. 
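    # (For the 2x2 case below, det([[1.0, 0.1], [0.1, 1.0]]) = 1 - 0.01 =
    #  0.99, so the expected log det is log(0.99) ~= -0.01005.  Equivalently,
    #  writing the matrix as I + A with A = [[0.0, 0.1], [0.1, 0.0]], the
    #  leading Taylor terms tr(A) - tr(A @ A)/2 = 0.0 - 0.02/2 = -0.01 give
    #  the same answer; this is the series that
    #  log_det_taylor_series_with_hutchinson truncates.)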
+ m = jnp.array([[1.5]], dtype=self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + log_det = lda( + m, + pc, + jax.random.PRNGKey(0), + 1, + ) + self.assertAlmostEqual(log_det, math.log(1.5), places=num_places) + + m = jnp.array([[0.5]], dtype=self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + log_det = lda( + m, + pc, + jax.random.PRNGKey(1), + 1, + ) + self.assertAlmostEqual(log_det, math.log(0.5), places=num_places) + + m = jnp.array([[1.0, 0.1], [0.1, 1.0]], dtype=self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + log_det = lda( + m, + pc, + jax.random.PRNGKey(2), + num_probe_vectors=200, + ) + self.assertAlmostEqual(log_det, math.log(0.99), places=num_places) + + @parameterized.parameters( + ('r1', 2.0), + ('r2', 1.0), + ('r3', 0.2), + ('r4', 0.1), + ('r5', 0.1), + ('r6', 0.01), + ('slq', 0.00001), + ) + def test_log_det_algorithm_diagonal_matrices(self, log_det_alg, delta): + lda = fast_log_det.get_log_det_algorithm(log_det_alg) + m = jnp.identity(5) + pc = preconditioners.IdentityPreconditioner(m) + self.assertAlmostEqual(lda(m, pc, jax.random.PRNGKey(0)), 0.0, delta=delta) + + m = jnp.diag( + jnp.arange( + 1.0, + 9.0, + ) + ) + pc = preconditioners.IdentityPreconditioner(m) + self.assertAlmostEqual( + lda(m, pc, jax.random.PRNGKey(1)), + # log 8! = + 10.6046029027, + delta=delta, + ) + + @parameterized.parameters( + (fast_log_det.ProbeVectorType.NORMAL, 0.5), + (fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL, 0.4), + (fast_log_det.ProbeVectorType.NORMAL_QMC, 0.9), + (fast_log_det.ProbeVectorType.RADEMACHER, 0.01) + ) + def test_r4_jits_different_probe_vector_types(self, probe_vector_type, delta): + my_log_det = jax.jit( + fast_log_det.r4, static_argnames=[ + 'num_probe_vectors', 'probe_vector_type']) + m = jnp.array([[1.5]], dtype=self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + ld = my_log_det( + m, + pc, + jax.random.PRNGKey(0), + num_probe_vectors=1, + probe_vector_type=probe_vector_type + ) + self.assertAlmostEqual(ld, math.log(1.5), delta=delta) + + m = jnp.array([[0.5]]) + pc = preconditioners.IdentityPreconditioner(m) + self.assertAlmostEqual( + my_log_det( + m, + pc, + jax.random.PRNGKey(1), + num_probe_vectors=1, + ), + math.log(0.5), + delta=delta + ) + + @parameterized.parameters( + ('r1', 1), + ('r2', 1), + ('r3', 3), + ('r4', 3), + ('r5', 3), + ('r6', 3), + ('slq', 3), + ) + def test_log_det_algorithm_jits(self, log_det_alg, num_places): + lda = fast_log_det.get_log_det_algorithm(log_det_alg) + + my_log_det = jax.jit(lda, static_argnames=['num_probe_vectors']) + m = jnp.array([[1.5]], dtype=self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + ld = my_log_det( + m, + pc, + jax.random.PRNGKey(0), + num_probe_vectors=1, + ) + self.assertAlmostEqual(ld, math.log(1.5), places=num_places) + + m = jnp.array([[0.5]]) + pc = preconditioners.IdentityPreconditioner(m) + self.assertAlmostEqual( + my_log_det( + m, + pc, + jax.random.PRNGKey(1), + num_probe_vectors=1, + ), + math.log(0.5), + places=num_places, + ) + + @parameterized.parameters( + ('r1', 5), + ('r2', 5), + ('r3', 5), + ('r4', 5), + ('r5', 5), + ('r6', 5), + ('slq', 5), + ) + def test_log_det_algorithm_derivative(self, log_det_alg, num_places): + lda = fast_log_det.get_log_det_algorithm(log_det_alg) + n = 10 + def log_det_of_scaled_identity(scale): + m = scale * jnp.identity(n).astype(self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + return lda(m, pc, jax.random.PRNGKey(0)) + + d_scale = jax.grad(log_det_of_scaled_identity)(self.dtype(2.0)) + # 
det(scale I) = scale^n + # log det(scale I) = n log scale + # d log det(scale I) = n (d scale) / scale + # d log det(scale I) / d scale = n / scale + self.assertAlmostEqual(d_scale, n / 2.0, places=num_places) + + # TODO(thomaswc,srvasude): Investigate why these numbers got so much worse + # after cl/564448268. Before that, the deltas were: + # ('r2', 0.3), ('r3', 0.3), ('r4', 0.2), ('r5', 0.2), ('r6', 0.3), + # ('slq', 0.5). + @parameterized.parameters( + ('r1', 0.8), + ('r2', 0.8), + ('r3', 0.7), + ('r4', 0.6), + ('r5', 0.6), + ('r6', 0.6), + ('slq', 1.4), + ) + def test_log_det_grad_of_random(self, log_det_alg, delta): + if self.dtype == np.float32: + self.skipTest('Numerically unstable in float32.') + lda = fast_log_det.get_log_det_algorithm(log_det_alg) + # Generate two random PSD matrices. + A = jax.random.uniform(jax.random.PRNGKey(0), shape=(10, 10), + minval=-1.0, maxval=1.0, dtype=self.dtype) + A = A @ jnp.transpose(A) + B = jax.random.uniform(jax.random.PRNGKey(1), shape=(10, 10), + minval=-1.0, maxval=1.0, dtype=self.dtype) + B = B @ jnp.transpose(B) + + def my_log_det(alpha): + m = A + alpha * B + pc = preconditioners.DiagonalSplitPreconditioner(m) + return lda(m, pc, jax.random.PRNGKey(2)) + + def std_log_det(alpha): + _, logdet = jnp.linalg.slogdet(A + alpha * B) + return logdet + + my_ld, my_grad = jax.value_and_grad(my_log_det)(self.dtype(1.0)) + std_ld, std_grad = jax.value_and_grad(std_log_det)(self.dtype(1.0)) + self.assertAlmostEqual(my_ld, std_ld, delta=delta) + self.assertAlmostEqual(my_grad, std_grad, delta=delta) + + def test_log00(self): + self.assertAlmostEqual( + fast_log_det.log00( + jnp.array([2.0, 3.0], dtype=self.dtype), + jnp.array([1.0], dtype=self.dtype)), + 0.5895146, + places=6, + ) + self.assertAlmostEqual( + fast_log_det.log00( + jnp.array([5.0, 6.0], dtype=self.dtype), + jnp.array([4.0], dtype=self.dtype)), + 1.2035519, + places=6, + ) + + @parameterized.parameters( + (fast_log_det.ProbeVectorType.NORMAL, 1.1, 0.4), + (fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL, 1.7, 0.6), + (fast_log_det.ProbeVectorType.NORMAL_QMC, 0.6, 0.3)) + def test_stochastic_lanczos_quadrature_normal_log_det( + self, probe_vector_type, error_float32, error_float64): + error = error_float32 if self.dtype == np.float32 else error_float64 + m = jnp.identity(5).astype(self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + num_probe_vectors = 25 + if probe_vector_type == fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL: + num_probe_vectors = 5 + self.assertAlmostEqual( + fast_log_det.stochastic_lanczos_quadrature_log_det( + m, + pc, + jax.random.PRNGKey(0), + num_probe_vectors=num_probe_vectors, + probe_vector_type=probe_vector_type, + ), + 0.0, + ) + m = jnp.diag(jnp.arange(1., 9.,).astype(self.dtype)) + if probe_vector_type == fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL: + num_probe_vectors = 8 + pc = preconditioners.IdentityPreconditioner(m) + self.assertAlmostEqual( + fast_log_det.stochastic_lanczos_quadrature_log_det( + m, + pc, + jax.random.PRNGKey(1), + num_probe_vectors=num_probe_vectors, + probe_vector_type=probe_vector_type, + ), + # log 8! = + 10.6046029027, + # At least for this example, normal probe vectors really degrade + # accuracy. 
+ delta=error, + ) + + def test_log_det_gradient_leaks(self): + n = 10 + + def scan_body(scale, _): + m = scale * jnp.identity(n).astype(self.dtype) + pc = preconditioners.IdentityPreconditioner(m) + return fast_log_det.log_det_order_four_rational_with_hutchinson( + m, pc, jax.random.PRNGKey(0)) + + def log_det_of_scaled_identity(scale): + _, y = jax.lax.scan(scan_body, scale, jnp.arange(1)) + return y[0] + + # Because of b/266429021, using nondiff_argnums over inputs that are or + # contain JAX arrays will cause a memory leak when used in a loop like a + # scan. + with self.assertRaises(Exception): + with jax.checking_leaks(): + unused_d_scale = jax.grad(log_det_of_scaled_identity)(2.0) + + @parameterized.parameters('r1', 'r2', 'r3', 'r4', 'r5', 'r6', 'slq') + def test_log_det_gradient_hard(self, algname): + log_det_fn = fast_log_det.get_log_det_algorithm(algname) + b = ( + jnp.diag(jnp.full(10, 2.0)) + + jnp.diag(jnp.full(9, 1.0), 1) + + jnp.diag(jnp.full(9, 1.0), -1) + ).astype(self.dtype) + + def fast_logdet(jitter): + m = b + jitter * jnp.identity(10).astype(self.dtype) + pc = preconditioners.PartialCholeskySplitPreconditioner(m) + return log_det_fn(m, pc, jax.random.PRNGKey(1)) + + def slow_logdet(jitter): + m = b + jitter * jnp.identity(10).astype(self.dtype) + return tfp.math.hpsd_logdet(m) + + d_fast = jax.grad(fast_logdet) + d_slow = jax.grad(slow_logdet) + + self.assertAlmostEqual(d_fast(0.1), d_slow(0.1), places=5) + self.assertAlmostEqual(d_fast(1.0), d_slow(1.0), places=6) + + def test_log_det_jvp(self): + if self.dtype == np.float32: + self.skipTest('Numerically unstable in float32.') + M = ( + jnp.diag(jnp.full(10, 2.0)) + + jnp.diag(jnp.full(9, 1.0), 1) + + jnp.diag(jnp.full(9, 1.0), -1) + ).astype(self.dtype) + M_dot = jnp.diag(jnp.full(10, 0.1).astype(self.dtype)) + pc = preconditioners.PartialCholeskySplitPreconditioner(M) + num_probe_vectors = 25 + probe_vectors = fast_log_det.make_probe_vectors( + 10, + num_probe_vectors, + jax.random.PRNGKey(0), + fast_log_det.ProbeVectorType.RADEMACHER, + dtype=self.dtype, + ) + _, tangent_out = fast_log_det.log_det_jvp( + (M, pc, probe_vectors, num_probe_vectors), + (M_dot, None, None, None), + lambda a, b, c, d, **kwargs: 0.0, + 20, + ) + truth = jnp.trace(jnp.linalg.inv(M) @ M_dot) + self.assertAlmostEqual(tangent_out, truth, places=6) + + @parameterized.parameters( + (fast_log_det.ProbeVectorType.NORMAL, 6), + (fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL, 6), + (fast_log_det.ProbeVectorType.NORMAL_QMC, 5)) + def test_log_det_jvp_normal_probe_vectors( + self, probe_vector_type, places): + M = ( + jnp.diag(jnp.full(10, 2.0)) + + jnp.diag(jnp.full(9, 1.0), 1) + + jnp.diag(jnp.full(9, 1.0), -1) + ).astype(self.dtype) + M_dot = jnp.diag(jnp.full(10, 0.1).astype(self.dtype)) + pc = preconditioners.PartialCholeskySplitPreconditioner(M) + num_probe_vectors = 25 + if probe_vector_type == fast_log_det.ProbeVectorType.NORMAL_ORTHOGONAL: + num_probe_vectors = 10 + probe_vectors = fast_log_det.make_probe_vectors( + 10, + num_probe_vectors, + jax.random.PRNGKey(0), + probe_vector_type, + dtype=self.dtype) + _, tangent_out = fast_log_det.log_det_jvp( + (M, pc, probe_vectors, num_probe_vectors), + (M_dot, None, None, None), + lambda a, b, c, d, **kwargs: 0.0, + 20, + ) + truth = jnp.trace(jnp.linalg.inv(M) @ M_dot) + self.assertAlmostEqual(tangent_out, truth, places=places) + + def test_log_det_jvp_hard(self): + if self.dtype == np.float32: + self.skipTest('Numerically unstable in float32.') + # Example from 
fast_gp_test.py:test_gaussian_process_log_prob_gradient + M = jnp.array([ + [ + 1.0020001, + 0.88311934, + 0.9894911, + 0.9695768, + 0.9987461, + 0.98577714, + 0.97863793, + 0.9880289, + 0.7110599, + 0.7718459, + ], + [ + 0.88311934, + 1.0020001, + 0.9395206, + 0.7564426, + 0.86025584, + 0.94721663, + 0.7791884, + 0.8075757, + 0.9478641, + 0.9758552, + ], + [ + 0.9894911, + 0.9395206, + 1.0020001, + 0.92534095, + 0.98108065, + 0.9997143, + 0.93953925, + 0.95583755, + 0.79332554, + 0.84795874, + ], + [ + 0.9695768, + 0.7564426, + 0.92534095, + 1.0020001, + 0.98049456, + 0.91640615, + 0.9991695, + 0.99564964, + 0.5614807, + 0.6257758, + ], + [ + 0.9987461, + 0.86025584, + 0.98108065, + 0.98049456, + 1.0020001, + 0.97622854, + 0.98763895, + 0.99449164, + 0.6813891, + 0.74358207, + ], + [ + 0.98577714, + 0.94721663, + 0.9997143, + 0.91640615, + 0.97622854, + 1.0020001, + 0.9313745, + 0.9487237, + 0.80610526, + 0.859435, + ], + [ + 0.97863793, + 0.7791884, + 0.93953925, + 0.9991695, + 0.98763895, + 0.9313745, + 1.0020001, + 0.99861676, + 0.5861309, + 0.65042824, + ], + [ + 0.9880289, + 0.8075757, + 0.95583755, + 0.99564964, + 0.99449164, + 0.9487237, + 0.99861676, + 1.0020001, + 0.61803514, + 0.68201244, + ], + [ + 0.7110599, + 0.9478641, + 0.79332554, + 0.5614807, + 0.6813891, + 0.80610526, + 0.5861309, + 0.61803514, + 1.0020001, + 0.9943819, + ], + [ + 0.7718459, + 0.9758552, + 0.84795874, + 0.6257758, + 0.74358207, + 0.859435, + 0.65042824, + 0.68201244, + 0.9943819, + 1.0020001, + ], + ], dtype=self.dtype) + M_dot = jnp.diag(jnp.full(10, 1.0).astype(self.dtype)) + pc = preconditioners.PartialCholeskySplitPreconditioner(M) + num_probe_vectors = 25 + probe_vectors = fast_log_det.make_probe_vectors( + 10, + num_probe_vectors, + jax.random.PRNGKey(1), + fast_log_det.ProbeVectorType.RADEMACHER, + dtype=self.dtype) + _, tangent_out = fast_log_det.log_det_jvp( + (M, pc, probe_vectors, num_probe_vectors), + (M_dot, None, None, None), + lambda a, b, c, d, **kwargs: 0.0, + 20, + ) + truth = jnp.trace(jnp.linalg.inv(M) @ M_dot) + self.assertAlmostEqual(tangent_out, truth, delta=0.02) + + def test_log_det_taylor_series_with_hutchinson(self): + # For 1x1 arrays, for Rademacher probe vectors, v^t A v = A = tr A, so + # only one probe vector is necessary. 
+ log_det = fast_log_det.log_det_taylor_series_with_hutchinson( + jnp.array([1.5], dtype=self.dtype), 1, jax.random.PRNGKey(0) + ) + self.assertAlmostEqual(log_det, math.log(1.5), places=3) + log_det = fast_log_det.log_det_taylor_series_with_hutchinson( + jnp.array([0.5], dtype=self.dtype), 1, jax.random.PRNGKey(1) + ) + self.assertAlmostEqual(log_det, math.log(0.5), places=3) + log_det = fast_log_det.log_det_taylor_series_with_hutchinson( + jnp.array([[1.0, 0.1], [0.1, 1.0]], dtype=self.dtype), + 200, + jax.random.PRNGKey(2), + ) + self.assertAlmostEqual(log_det, math.log(0.99), places=3) + + def test_log_det_taylor_series_with_hutchinson_order_two_isnt_random(self): + log_det = fast_log_det.log_det_taylor_series_with_hutchinson( + jnp.array([[1.0, 0.1], [0.1, 1.0]], dtype=self.dtype), + 0, + jax.random.PRNGKey(0), + 2, + ) + self.assertAlmostEqual(log_det, math.log(0.99), places=3) + log_det2 = fast_log_det.log_det_taylor_series_with_hutchinson( + jnp.array([[1.0, 0.1], [0.1, 1.0]], dtype=self.dtype), + 0, + jax.random.PRNGKey(1), + 2, + ) + self.assertAlmostEqual(log_det, log_det2, places=8) + + +class FastLogDetTestFloat32(_FastLogDetTest): + dtype = np.float32 + + +class FastLogDetTestFloat64(_FastLogDetTest): + dtype = np.float64 + + +del _FastLogDetTest + + +if __name__ == '__main__': + config.update('jax_enable_x64', True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py b/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py new file mode 100644 index 0000000000..2351621ca3 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py @@ -0,0 +1,255 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Fast likelihoods etc. for Multi-task Gaussian Processeses.""" + +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.fastgp import fast_gp +from tensorflow_probability.python.experimental.fastgp import fast_log_det +from tensorflow_probability.python.experimental.fastgp import linear_operator_sum +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.substrates import jax as tfp + +ps = tfp.internal.prefer_static +tfd = tfp.distributions +tfed = tfp.experimental.distributions +tfek = tfp.experimental.psd_kernels +jtf = tfp.tf2jax +Array = jnp.ndarray + + +def _vec(x): + # Vec takes in a (batch) of matrices of shape B1 + [n, k] and returns + # a (batch) of vectors of shape B1 + [n * k]. + return jnp.reshape(x, ps.concat([ps.shape(x)[:-2], [-1]], axis=0)) + + +def _unvec(x, matrix_shape): + # Unvec takes in a (batch) of matrices of shape B1 + [n * k] and returns + # a (batch) of vectors of shape B1 + [n, k], where n and k are specified + # by matrix_shape. 
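  # Shape example (illustrative): for x of shape [7, 15] and
  # matrix_shape = [5, 3],
  #   _unvec(x, [5, 3]).shape == (7, 5, 3)
  # and _vec is the inverse map: _vec(_unvec(x, [5, 3])).shape == (7, 15).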
+ return jnp.reshape(x, ps.concat([ps.shape(x)[:-1], matrix_shape], axis=0)) + + +class MultiTaskGaussianProcess(tfd.AutoCompositeTensorDistribution): + """Fast, JAX-only implementation of a MTGP distribution class. + + See tfed.distributions.MultiTaskGaussianProcess for a description and + parameter documentation. + """ + + def __init__( + self, + kernel, + index_points=None, + mean_fn=None, + observation_noise_variance=0.0, + config=fast_gp.GaussianProcessConfig(), + validate_args=False, + allow_nan_stats=False): + """Instantiate a fast GaussianProcess distribution. + + Args: + kernel: A `PositiveSemidefiniteKernel`-like instance representing the GP's + covariance function. + index_points: Tensor specifying the points over which the GP is defined. + mean_fn: Python callable that acts on index_points. Default `None` + implies a constant zero mean function. + observation_noise_variance: `float` `Tensor` representing the scalar + variance of the noise in the Normal likelihood distribution of the + model. + config: `GaussianProcessConfig` to control speed and quality of GP + approximations. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. Default value: `False`. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or more + of the statistic's batch members are undefined. Default value: `False`. + + """ + parameters = dict(locals()) + if jax.tree_util.treedef_is_leaf( + jax.tree_util.tree_structure(kernel.feature_ndims)): + # If the index points are not nested, we assume they are of the same + # float dtype as the GP. 
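      # ("Nested" means structured index points here, e.g. a kernel whose
      #  feature_ndims is a dict such as {'x': 1, 'y': 2}; in that case the
      #  index points are left out of the dtype inference below.)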
+ dtype = tfp.internal.dtype_util.common_dtype( + {'index_points': index_points, + 'observation_noise_variance': observation_noise_variance}, + jnp.float32) + else: + dtype = tfp.internal.dtype_util.common_dtype( + {'observation_noise_variance': observation_noise_variance}, + jnp.float32) + + self._kernel = kernel + self._index_points = index_points + self._mean_fn = ( + tfd.internal.stochastic_process_util.maybe_create_multitask_mean_fn( + mean_fn, kernel, dtype)) + self._observation_noise_variance = ( + tfp.internal.tensor_util.convert_nonref_to_tensor( + observation_noise_variance + ) + ) + self._config = config + self._probe_vector_type = fast_log_det.ProbeVectorType[ + config.probe_vector_type.upper()] + self._log_det_fn = fast_log_det.get_log_det_algorithm( + config.log_det_algorithm) + + super(MultiTaskGaussianProcess, self).__init__( + dtype=dtype, + reparameterization_type=tfd.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + name='MultiTaskGaussianProcess') + + @property + def kernel(self): + return self._kernel + + @property + def index_points(self): + return self._index_points + + @property + def mean_fn(self): + return self._mean_fn + + @property + def observation_noise_variance(self): + return self._observation_noise_variance + + @property + def event_shape(self): + return tfd.internal.stochastic_process_util.multitask_event_shape( + self._kernel, self.index_points) + + def _mean(self): + loc = self._mean_fn(self._index_points) + return jnp.broadcast_to(loc, self.event_shape) + + def _variance(self): + index_points = self._index_points + kernel_matrix = self.kernel.matrix_over_all_tasks( + index_points, index_points) + observation_noise_variance = self.observation_noise_variance + # We can add the observation noise to each block. + if isinstance(self.kernel, tfek.Independent): + single_task_variance = kernel_matrix.operators[0].diag_part() + if observation_noise_variance is not None: + single_task_variance = ( + single_task_variance + observation_noise_variance[..., jnp.newaxis]) + # Each task has the same variance, so shape this in to an `[..., e, t]` + # shaped tensor and broadcast to batch shape + variance = jnp.stack( + [single_task_variance] * self.kernel.num_tasks, axis=-1) + return variance + + # If `kernel_matrix` has structure, `diag_part` will try to take advantage + # of that structure. In the case of a `Separable` kernel, `diag_part` will + # efficiently compute the diagonal of a kronecker product. + variance = kernel_matrix.diag_part() + if observation_noise_variance is not None: + variance = ( + variance + + observation_noise_variance[..., jnp.newaxis]) + + variance = _unvec(variance, [-1, self.kernel.num_tasks]) + + # Finally broadcast with batch shape. + batch_shape = self._batch_shape_tensor(index_points=index_points) + event_shape = self._event_shape_tensor(index_points=index_points) + + variance = jnp.broadcast_to( + variance, ps.concat([batch_shape, event_shape], axis=0)) + return variance + + @jax.named_call + def log_prob(self, value, key) -> Array: + """log P(value | GP).""" + empty_sample_batch_shape = value.ndim == 2 + if empty_sample_batch_shape: + value = value[jnp.newaxis] + if value.ndim != 3: + raise ValueError( + 'fast_mtgp.MultiTaskGaussianProcess.log_prob only supports values ' + f'of rank 2 or 3, got rank {value.ndim} instead.' 
+ ) + index_points = self._index_points + loc = self.mean() + loc = _vec(loc) + covariance = self.kernel.matrix_over_all_tasks(index_points, index_points) + + centered_value = _vec(value) - loc + key1, key2 = jax.random.split(key) + + is_scaling_preconditioner = self._config.preconditioner.endswith('scaling') + + def get_preconditioner(cov): + scaling = None + if is_scaling_preconditioner: + scaling = self.observation_noise_variance + return preconditioners.get_preconditioner( + self._config.preconditioner, + cov, + key=key1, + rank=self._config.preconditioner_rank, + num_iters=self._config.preconditioner_num_iters, + scaling=scaling) + + if is_scaling_preconditioner: + preconditioner = get_preconditioner(covariance) + + covariance = linear_operator_sum.LinearOperatorSum( + [covariance, + jtf.linalg.LinearOperatorScaledIdentity( + num_rows=covariance.range_dimension, + multiplier=self._observation_noise_variance)]) + + if not is_scaling_preconditioner: + preconditioner = get_preconditioner(covariance) + + # TODO(srvasude): Specialize for Independent and Separable kernels. + # In particular, we should be able to take advantage of the kronecker + # structure, and construct a kronecker preconditioner. + + det_term = self._log_det_fn( + covariance, + preconditioner, + key=key2, + num_probe_vectors=self._config.num_probe_vectors, + probe_vector_type=self._probe_vector_type, + num_iters=self._config.log_det_iters, + ) + + exp_term = fast_gp.yt_inv_y( + covariance, + preconditioner.full_preconditioner(), + jnp.transpose(centered_value), + max_iters=self._config.cg_iters + ) + + lp = -0.5 * ( + fast_gp.LOG_TWO_PI * value.shape[-1] * value.shape[-2] + + det_term + exp_term) + if empty_sample_batch_shape: + return jnp.squeeze(lp, axis=0) + + return lp diff --git a/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py b/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py new file mode 100644 index 0000000000..98e0b0c106 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/fast_mtgp_test.py @@ -0,0 +1,461 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for fast_gp.py.""" + +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import fast_gp +from tensorflow_probability.python.experimental.fastgp import fast_mtgp +from tensorflow_probability.substrates import jax as tfp +from absl.testing import absltest + +jtf = tfp.tf2jax +tfd = tfp.distributions +tfed = tfp.experimental.distributions + + +class _FastMultiTaskGpTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.points = np.random.rand(100, 30).astype(self.dtype) + + def test_gaussian_process_copy(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0) + ) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype) + my_gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, index_points, observation_noise_variance=self.dtype(3e-3) + ) + my_gp_copy = my_gp.copy(config=fast_gp.GaussianProcessConfig( + preconditioner_num_iters=20)) + my_gp_params = my_gp.parameters.copy() + my_gp_copy_params = my_gp_copy.parameters.copy() + self.assertNotEqual(my_gp_params.pop("config"), + my_gp_copy_params.pop("config")) + self.assertEqual(my_gp_params, my_gp_copy_params) + + def test_gaussian_process_log_prob(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0) + ) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype) + my_gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling") + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + np.testing.assert_allclose( + my_gp.log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=2e-3, + ) + + def test_gaussian_process_log_prob_separable(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0) + ) + task_cholesky = np.array([ + [1.4, 0., 0.], + [0.5, 1.23, 0.], + [0.25, 0.3, 1.34]], dtype=self.dtype) + task_cholesky_linop = jtf.linalg.LinearOperatorLowerTriangular( + task_cholesky) + kernel = tfp.experimental.psd_kernels.Separable( + 3, task_kernel_scale_linop=task_cholesky_linop, base_kernel=kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling") + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky, + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + np.testing.assert_allclose( + my_gp.log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=8e-3, + ) + + def test_gaussian_process_log_prob_single_sample(self): + kernel = 
tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0) + ) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling") + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky + ) + single_sample = slow_gp.sample(seed=jax.random.PRNGKey(0)) + lp = my_gp.log_prob(single_sample, key=jax.random.PRNGKey(1)) + self.assertEqual(single_sample.ndim, 2) + self.assertEmpty(lp.shape) + np.testing.assert_allclose( + lp, + slow_gp.log_prob(single_sample), + rtol=7e-4, + ) + + def test_gaussian_process_log_prob2(self): + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype) + samples = jnp.array([[ + [-0.0980842, -0.0980842, -0.0980842], + [-0.27192444, -0.27192444, -0.27192444], + [-0.22313793, -0.22313793, -0.22313793], + [-0.07691351, -0.07691351, -0.07691351], + [-0.1314459, -0.1314459, -0.1314459], + [-0.2322599, -0.2322599, -0.2322599], + [-0.1493263, -0.1493263, -0.1493263], + [-0.11629149, -0.11629149, -0.11629149], + [-0.34304297, -0.34304297, -0.34304297], + [-0.24659207, -0.24659207, -0.24659207] + ]]).astype(self.dtype) + + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.0), self.dtype(1.0)) + k = tfp.experimental.psd_kernels.Independent(3, k) + fgp = fast_mtgp.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner_rank=30, + preconditioner="partial_cholesky_plus_scaling"), + ) + sgp = tfed.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + cholesky_fn=jnp.linalg.cholesky + ) + + fast_ll = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1))) + slow_ll = jnp.sum(sgp.log_prob(samples)) + np.testing.assert_allclose(fast_ll, slow_ll, rtol=2e-4) + + def test_gaussian_process_log_prob_jits(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0)) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + my_gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling"), + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + my_gp_log_prob = jax.jit(my_gp.log_prob) + np.testing.assert_allclose( + my_gp_log_prob(samples, key=jax.random.PRNGKey(1)), + slow_gp.log_prob(samples), + rtol=2e-3, + ) + + def test_gp_log_prob_hard(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(1.0)) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + 
index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky, + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude=self.dtype(1.0), length_scale=self.dtype(1.0)) + task_cholesky = np.array([ + [1.4, 0., 0.], + [0.5, 1.23, 0.], + [0.25, 0.3, 1.34]], dtype=self.dtype) + task_cholesky_linop = jtf.linalg.LinearOperatorLowerTriangular( + task_cholesky) + k = tfp.experimental.psd_kernels.Separable( + 3, task_kernel_scale_linop=task_cholesky_linop, base_kernel=k) + + fgp = fast_mtgp.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner_rank=30, + preconditioner="partial_cholesky_plus_scaling"), + ) + sgp = tfed.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(1e-3), + cholesky_fn=jnp.linalg.cholesky, + ) + fgp_lp = jnp.sum(fgp.log_prob(samples, key=jax.random.PRNGKey(1))) + sgp_lp = jnp.sum(sgp.log_prob(samples)) + np.testing.assert_allclose(fgp_lp, sgp_lp, rtol=3e-4) + + def test_gp_log_prob_matern_five_halves(self): + kernel = tfp.math.psd_kernels.MaternFiveHalves( + self.dtype(2.0), self.dtype(1.0)) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 5), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + sgp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(0.1), + cholesky_fn=jnp.linalg.cholesky, + ) + sample = sgp.sample(1, seed=jax.random.PRNGKey(0)) + fgp = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(0.1), + config=fast_gp.GaussianProcessConfig( + preconditioner_rank=30, + preconditioner="partial_cholesky_plus_scaling"), + ) + fgp_lp = jnp.sum(fgp.log_prob(sample, key=jax.random.PRNGKey(1))) + sgp_lp = jnp.sum(sgp.log_prob(sample)) + np.testing.assert_allclose(fgp_lp, sgp_lp, rtol=1e-5) + + def test_gaussian_process_log_prob_gradient(self): + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(0.5), self.dtype(2.0)) + kernel = tfp.experimental.psd_kernels.Independent(3, kernel) + index_points = jax.random.uniform( + jax.random.PRNGKey(0), shape=(10, 2), minval=-1.0, maxval=1.0, + dtype=self.dtype + ) + slow_gp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky + ) + samples = slow_gp.sample(5, seed=jax.random.PRNGKey(0)) + + def log_prob(amplitude, length_scale, noise): + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude, length_scale + ) + k = tfp.experimental.psd_kernels.Independent(3, k) + gp = fast_mtgp.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=noise, + config=fast_gp.GaussianProcessConfig( + preconditioner_rank=30, + preconditioner="partial_cholesky_plus_scaling"), + ) + return jnp.sum(gp.log_prob(samples, key=jax.random.PRNGKey(1))) + + value, gradient = jax.value_and_grad(log_prob, argnums=[0, 1, 2])( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) + d_amp, d_length_scale, d_noise = gradient + self.assertFalse(jnp.isnan(value)) + self.assertFalse(jnp.isnan(d_amp)) + self.assertFalse(jnp.isnan(d_length_scale)) + self.assertFalse(jnp.isnan(d_noise)) + + def slow_log_prob(amplitude, length_scale, noise): + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + amplitude, length_scale + ) + k = 
tfp.experimental.psd_kernels.Independent(3, k) + gp = tfed.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=noise, + cholesky_fn=jnp.linalg.cholesky, + ) + return jnp.sum(gp.log_prob(samples)) + + direct_value = log_prob( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) + direct_slow_value = slow_log_prob( + self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) + np.testing.assert_allclose(direct_value, direct_slow_value, rtol=4e-4) + + slow_value, slow_gradient = jax.value_and_grad( + slow_log_prob, argnums=[0, 1, 2] + )(self.dtype(1.0), self.dtype(1.0), self.dtype(1e-3)) + np.testing.assert_allclose(value, slow_value, rtol=4e-4) + slow_d_amp, slow_d_length_scale, slow_d_noise = slow_gradient + np.testing.assert_allclose(d_amp, slow_d_amp, rtol=1e-4) + np.testing.assert_allclose(d_length_scale, slow_d_length_scale, rtol=1e-4) + # TODO(thomaswc): Investigate why the noise gradient is so noisy. + np.testing.assert_allclose(d_noise, slow_d_noise, rtol=1e-4) + + def test_gaussian_process_log_prob_gradient_of_index_points(self): + samples = jnp.array([ + [-0.7, -0.1, -0.2], + [-0.5, -0.3, -0.2], + [-0.3, -0.1, -0.1], + ], dtype=self.dtype) + + def fast_log_prob(pt1, pt2, pt3): + index_points = jnp.array([[pt1], [pt2], [pt3]]) + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.1), self.dtype(0.9)) + k = tfp.experimental.psd_kernels.Independent(3, k) + gp = fast_mtgp.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(3e-3), + config=fast_gp.GaussianProcessConfig( + preconditioner="partial_cholesky_plus_scaling"), + ) + lp = gp.log_prob(samples, key=jax.random.PRNGKey(1)) + return jnp.sum(lp) + + def slow_log_prob(pt1, pt2, pt3): + index_points = jnp.array([[pt1], [pt2], [pt3]]) + k = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(1.1), self.dtype(0.9)) + k = tfp.experimental.psd_kernels.Independent(3, k) + gp = tfed.MultiTaskGaussianProcess( + k, + index_points, + observation_noise_variance=self.dtype(3e-3), + cholesky_fn=jnp.linalg.cholesky, + ) + lp = gp.log_prob(samples) + return jnp.sum(lp) + + direct_slow_value = slow_log_prob( + self.dtype(-0.5), self.dtype(0.0), self.dtype(0.5)) + direct_fast_value = fast_log_prob( + self.dtype(-0.5), self.dtype(0.0), self.dtype(0.5)) + np.testing.assert_allclose(direct_slow_value, direct_fast_value, rtol=3e-5) + + slow_value, slow_gradient = jax.value_and_grad( + slow_log_prob, argnums=[0, 1, 2] + )(self.dtype(-0.5), self.dtype(0.0), self.dtype(0.5)) + + fast_value, fast_gradient = jax.value_and_grad( + fast_log_prob, argnums=[0, 1, 2] + )(self.dtype(-0.5), self.dtype(0.0), self.dtype(0.5)) + np.testing.assert_allclose(fast_value, slow_value, rtol=3e-5) + np.testing.assert_allclose(fast_gradient, slow_gradient, rtol=1e-4) + + def test_gaussian_process_mean(self): + mean_fn = lambda x: jnp.stack([x[:, 0]**2, x[:, 0]**3], axis=-1) + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic() + kernel = tfp.experimental.psd_kernels.Independent(2, kernel) + index_points = np.expand_dims( + np.random.uniform(-1., 1., 10).astype(self.dtype), -1) + gp = fast_mtgp.MultiTaskGaussianProcess( + kernel, index_points, mean_fn=mean_fn) + expected_mean = mean_fn(index_points) + np.testing.assert_allclose( + expected_mean, gp.mean(), rtol=1e-5) + + def test_gaussian_process_variance(self): + amp = self.dtype(.5) + len_scale = self.dtype(.2) + observation_noise_variance = self.dtype(3e-3) + + kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(amp, len_scale) + kernel = 
tfp.experimental.psd_kernels.Independent(2, kernel) + + index_points = np.expand_dims( + np.random.uniform(-1., 1., 10).astype(self.dtype), -1) + + fast = fast_mtgp.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=observation_noise_variance) + mtgp = tfed.MultiTaskGaussianProcess( + kernel, + index_points, + observation_noise_variance=observation_noise_variance, + cholesky_fn=jnp.linalg.cholesky) + np.testing.assert_allclose( + mtgp.variance(), fast.variance(), rtol=1e-5) + + +class FastMultiTaskGpTestFloat32(_FastMultiTaskGpTest): + dtype = np.float32 + + +class FastMultiTaskGpTestFloat64(_FastMultiTaskGpTest): + dtype = np.float64 + + +del _FastMultiTaskGpTest + + +if __name__ == "__main__": + config.update("jax_enable_x64", True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/linalg.py b/tensorflow_probability/python/experimental/fastgp/linalg.py new file mode 100644 index 0000000000..59b3a30f45 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/linalg.py @@ -0,0 +1,189 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Linear algebra routines to use for preconditioners.""" + +import functools +import jax +import jax.experimental.sparse +import jax.numpy as jnp +from jaxtyping import Float +import numpy as np +from tensorflow_probability.python.experimental.fastgp import partial_lanczos +from tensorflow_probability.substrates import jax as tfp + +jtf = tfp.tf2jax +Array = jnp.ndarray + +# pylint: disable=invalid-name + + +def _matvec(M, x) -> jax.Array: + if isinstance(M, jtf.linalg.LinearOperator): + return M.matvec(x) + return M @ x + + +def largest_eigenvector( + M: jtf.linalg.LinearOperator, key: jax.Array, num_iters: int = 10 +) -> tuple[Float, Array]: + """Returns the largest (eigenvalue, eigenvector) of M.""" + n = M.shape[-1] + v = jax.random.uniform(key, shape=(n,), dtype=M.dtype) + for _ in range(num_iters): + v = _matvec(M, v) + v = v / jnp.linalg.norm(v) + + nv = _matvec(M, v) + eigenvalue = jnp.linalg.norm(nv) + return eigenvalue, v + + +def make_randomized_truncated_svd( + key: jax.Array, + M: jtf.linalg.LinearOperator, + rank: int = 20, + oversampling: int = 10, + num_iters: int = 4) -> tuple[Float, Array]: + """Returns approximate SVD for symmetric `M`.""" + # This is based on: + # N. Halko, P.G. Martinsson, J. A. Tropp + # Finding Structure with Randomness: Probabilistic Algorithms for Constucting + # Approximate Matrix Decompositions + # We use the recommended oversampling parameter of 10 by default. + # https://arxiv.org/pdf/0909.4061.pdf + # It's also recommended to run for 2-4 iterations. 
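+  # Outline of the computation below: draw a random test matrix p with
+  # rank + oversampling columns, run num_iters rounds of power iteration with
+  # QR re-orthonormalization (using the symmetry of M to avoid transposes),
+  # take the SVD of the small matrix p^t M p, and return p @ u * sqrt(s)
+  # truncated to `rank` columns, so the result times its own transpose
+  # approximates M.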
+ max_rank = min(M.shape[-2:]) + if rank > max_rank + 1: + print( + 'Warning, make_randomized_truncated_svd called ' + f'with {rank=} and {max_rank=}') + rank = max_rank + 1 + p = jax.random.uniform( + key, + shape=M.shape[:-1] + (rank + oversampling,), + dtype=M.dtype, + minval=-1., + maxval=1.) + for _ in range(num_iters): + # We will assume that M is symmetric to avoid a transpose. + q = M @ p + q, _ = jnp.linalg.qr(q) + p = M @ q + p, _ = jnp.linalg.qr(p) + + # SVD of Q*MQ + u, s, _ = jnp.linalg.svd( + jnp.swapaxes(p, -1, -2) @ (M @ p), hermitian=True) + + return (p @ u * jnp.sqrt(s))[..., :rank] + + +def make_partial_lanczos( + key: jax.Array, M: jtf.linalg.LinearOperator, rank: int) -> Array: + """Return low rank approximation to M based on the partial Lancozs alg.""" + n = M.shape[-1] + key1, key2 = jax.random.split(key) + v = jax.random.uniform( + key1, shape=(n, 1), minval=-1.0, maxval=1.0, dtype=M.dtype) + Q, T = partial_lanczos.partial_lanczos(lambda x: M @ x, v, key2, rank) + + # Now diagonalize T as R^t D R. + full_T = ( + jnp.diag(T.diag[0, :]) + + jnp.diag(T.off_diag[0, :], 1) + + jnp.diag(T.off_diag[0, :], -1) + ) + # TODO(thomaswc): When jnp.linalg includes eigh_tridiagonal, replace this + # with that. + evalues, evectors = jnp.linalg.eigh(full_T) + sqrt_evalues = jnp.sqrt(evalues) + + # M ~ F^t F, where F = sqrt(D) R Q. + F = jnp.einsum('i,ij->ij', sqrt_evalues, evectors @ Q[0]) + low_rank = jnp.transpose(F) + + return low_rank + + +def make_truncated_svd( + key, M: jtf.linalg.LinearOperator, rank: int, num_iters: int) -> Array: + """Return low rank approximation to M based on the partial SVD alg.""" + n = M.shape[-1] + if 5 * rank >= n: + print( + f'Warning, make_truncated_svd called with {rank=} and {n=}') + rank = int((n-1) / 5) + if rank > 0: + X = jax.random.uniform( + key, shape=(n, rank), minval=-1.0, maxval=1.0, dtype=M.dtype) + evalues, evectors, _ = jax.experimental.sparse.linalg.lobpcg_standard( + M, X, num_iters, None) + low_rank = evectors * jnp.sqrt(evalues) + else: + low_rank = jnp.zeros((n, 0), dtype=M.dtype) + return low_rank + + +@functools.partial(jax.jit, static_argnums=1) +def make_partial_pivoted_cholesky( + M: jtf.linalg.LinearOperator, rank: int) -> Array: + """Return low rank approximation to M based on partial pivoted Cholesky.""" + n = M.shape[-1] + + def swap_row(a, index1, index2): + temp = a[index1] + a = jax.lax.dynamic_update_index_in_dim(a, a[index2], index1, axis=0) + a = jax.lax.dynamic_update_index_in_dim(a, temp, index2, axis=0) + return a + + def body_fn(i, val): + diag, transpositions, permutation, low_rank = val + largest_index = jnp.argmax(diag) + transpositions = jax.lax.dynamic_update_index_in_dim( + transpositions, jnp.array([i, largest_index]), i, axis=0) + diag = swap_row(diag, largest_index, i) + low_rank = swap_row(low_rank.T, largest_index, i).T + permutation = swap_row(permutation, largest_index, i) + + pivot = jnp.sqrt(diag[i]) + row = M[permutation[i], :] + + def reswap_row(index, row): + index, transposition_i = jax.lax.dynamic_index_in_dim( + transpositions, index, 0, keepdims=False) + return swap_row(row.T, index, transposition_i).T + row = jax.lax.fori_loop(0, i + 1, reswap_row, row) + + low_rank_i = jax.lax.dynamic_index_in_dim(low_rank, i, 1, keepdims=False) + low_rank_i = jnp.where( + jnp.arange(n) > i, + (row - jnp.dot(low_rank_i, low_rank)) / pivot, 0.) 
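+    # Entries at positions <= i were masked to zero by the jnp.where above;
+    # the pivot value itself is written back at position i just below.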
+ low_rank_i = jax.lax.dynamic_update_index_in_dim( + low_rank_i, pivot, i, axis=0) + low_rank = jax.lax.dynamic_update_index_in_dim( + low_rank, low_rank_i, i, axis=0) + diag -= jnp.where(jnp.arange(n) >= i, low_rank_i**2, 0.) + return diag, transpositions, permutation, low_rank + + diag = jnp.diag(M) + _, _, permutation, low_rank = jax.lax.fori_loop( + 0, rank, body_fn, ( + diag, + -jnp.ones([rank, 2], dtype=np.int64), + jnp.arange(n, dtype=np.int64), + jnp.zeros([rank, n], dtype=M.dtype))) + # Invert the permutation + permutation = jnp.argsort(permutation, axis=-1) + return low_rank.T[..., permutation, :] diff --git a/tensorflow_probability/python/experimental/fastgp/linalg_test.py b/tensorflow_probability/python/experimental/fastgp/linalg_test.py new file mode 100644 index 0000000000..c789a6a393 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/linalg_test.py @@ -0,0 +1,88 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test for linalg.py.""" + +from absl.testing import parameterized +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import linalg +import tensorflow_probability.substrates.jax as tfp +from absl.testing import absltest + +jtf = tfp.tf2jax + + +# pylint: disable=invalid-name + + +class _LinalgTest(parameterized.TestCase): + + def test_largest_eigenvector(self): + M = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + evalue, evector = linalg.largest_eigenvector(M, jax.random.PRNGKey(0)) + self.assertAlmostEqual(5.415476, evalue, delta=0.1) + np.testing.assert_allclose( + jnp.array([0.4926988, 0.8701998]), evector, atol=0.1 + ) + + @parameterized.parameters(2, 3, 5, 10, 15) + def test_randomized_truncated_svd(self, n): + # Check that this matches the non-randomized SVD. 
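+    # M is symmetric positive definite by construction (A^T A plus a multiple
+    # of the identity), matching the symmetry assumption made in
+    # make_randomized_truncated_svd.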
+ A = jax.random.uniform( + jax.random.PRNGKey(1), shape=(n, n), + minval=-1.0, maxval=1.0).astype(self.dtype) + M = A.T @ A + 0.6 * jnp.eye(n).astype(self.dtype) + low_rank = linalg.make_randomized_truncated_svd( + jax.random.PRNGKey(2), M, rank=n, num_iters=2) + self.assertEqual(low_rank.shape, (n, n)) + + @parameterized.parameters(2, 3, 5, 10, 15) + def test_pivoted_cholesky_exact(self, n): + A = jax.random.uniform( + jax.random.PRNGKey(3), shape=(n, n), + minval=-1.0, maxval=1.0).astype(self.dtype) + M = A.T @ A + 0.6 * jnp.eye(n).astype(self.dtype) + + low_rank = linalg.make_partial_pivoted_cholesky(M, rank=n) + self.assertEqual(low_rank.shape, (n, n)) + np.testing.assert_allclose(M, low_rank @ low_rank.T, rtol=5e-6) + + @parameterized.parameters(2, 3, 5, 10, 15) + def test_pivoted_cholesky_approx(self, n): + A = jax.random.uniform( + jax.random.PRNGKey(3), shape=(n, n), + minval=-1.0, maxval=1.0).astype(self.dtype) + M = A.T @ A + 0.6 * jnp.eye(n).astype(self.dtype) + + low_rank = linalg.make_partial_pivoted_cholesky(M, rank=n // 2) + self.assertEqual(low_rank.shape, (n, n // 2)) + + +class LinalgTestFloat32(_LinalgTest): + dtype = np.float32 + + +class LinalgTestFloat64(_LinalgTest): + dtype = np.float64 + + +del _LinalgTest + + +if __name__ == "__main__": + config.update("jax_enable_x64", True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py b/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py new file mode 100644 index 0000000000..7c6dd52316 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py @@ -0,0 +1,128 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Expresses a sum of operators.""" + +import jax +from tensorflow_probability.substrates import jax as tfp + +jtf = tfp.tf2jax + + +@jax.tree_util.register_pytree_node_class +class LinearOperatorSum(jtf.linalg.LinearOperator): + """Encapsulates a sum of linear operators.""" + + def __init__(self, + operators, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + is_square=None, + name="LinearOperatorSum"): + r"""Initialize a `LinearOperatorSum`. + + A Sum Operator, expressing `(A[0] + A[1] + A[2] + ... A[N])`, where + `A` is a list of operators. + + This is useful to encapsulate a sum of structured operators without + densifying them. + + Args: + operators: `List` of `LinearOperator`s. + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. If `diag.dtype` is real, this is auto-set to `True`. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. 
See: + https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices + is_square: Expect that this operator acts like square [batch] matrices. + name: A name for this `LinearOperator`. + """ + parameters = dict( + operators=operators, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + name=name + ) + if not operators: + raise ValueError("Expected a non-empty list of `operators`.") + self._operators = operators + dtype = operators[0].dtype + for operator in operators: + if operator.dtype != dtype: + raise TypeError( + "Expected every operation in `operators` to have the same " + "dtype.") + + if all(operator.is_self_adjoint for operator in operators): + if is_self_adjoint is False: # pylint: disable=g-bool-id-comparison + raise ValueError( + f"The sum of self-adjoint operators is always " + f"self-adjoint. Expected argument `is_self_adjoint` to be True. " + f"Received: {is_self_adjoint}.") + is_self_adjoint = True + + if all(operator.is_positive_definite for operator in operators): + if is_positive_definite is False: # pylint: disable=g-bool-id-comparison + raise ValueError( + f"The sum of positive-definite operators is always " + f"positive-definite. Expected argument `is_positive_definite` to " + f"be True. Received: {is_positive_definite}.") + is_positive_definite = True + + super(LinearOperatorSum, self).__init__( + dtype=dtype, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + parameters=parameters, + name=name) + + def _shape(self): + return self._operators[0].shape + + def _shape_tensor(self): + return self._operators[0].shape_tensor() + + @property + def operators(self): + return self._operators + + def _matmul(self, x, adjoint=False, adjoint_arg=False): + return sum(o.matmul( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) for o in self.operators) + + def _to_dense(self): + return sum(o.to_dense() for o in self.operators) + + @property + def _composite_tensor_fields(self): + return ("operators",) + + def tree_flatten(self): + return ((self.operators,), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls(*children) + + @property + def _experimental_parameter_ndims_to_matrix_ndims(self): + return {"operators": [0] * len(self.operators)} diff --git a/tensorflow_probability/python/experimental/fastgp/mbcg.py b/tensorflow_probability/python/experimental/fastgp/mbcg.py new file mode 100644 index 0000000000..8a52da62d3 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/mbcg.py @@ -0,0 +1,204 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Modified Batched Conjugate Gradients (mBCG) for JAX.""" + +from typing import Callable, List, NamedTuple, Tuple + +import jax +import jax.numpy as jnp + +Array = jnp.ndarray + + +class SymmetricTridiagonalMatrix(NamedTuple): + """Holds a batch of symmetric, tridiagonal matrices.""" + diag: Array + off_diag: Array + + +def safe_stack(list_of_arrays: List[Array], empty_size: int) -> Array: + """Like jnp.stack, but handles len == 0 or 1.""" + l = len(list_of_arrays) + if l == 0: + return jnp.empty(shape=(empty_size, 0)) + if l == 1: + return list_of_arrays[0][:, jnp.newaxis] + return jnp.stack(list_of_arrays, axis=-1) + + +# pylint: disable=invalid-name + + +@jax.named_call +def modified_batched_conjugate_gradients( + matrix_matrix_multiplier: Callable[[Array], Array], + B: Array, + preconditioner_fn: Callable[[Array], Array], + max_iters: int = 20, + tolerance: float = 1e-6, +) -> Tuple[Array, SymmetricTridiagonalMatrix]: + """Return A^(-1)B and Lanczos tridiagonal matrices. + + Based on Algorithm 2 on page 14 of https://arxiv.org/pdf/1809.11165.pdf + + Args: + matrix_matrix_multiplier: A function for left-matrix multiplying by an n x n + matrix A, which should be symmetric and positive definite. + B: An n x t matrix containing the vectors v_i for which we want A^(-1) v_i. + preconditioner_fn: A function that applies an invertible linear + transformation to its input, designed to increase the rate of convergence + by decreasing the condition number. The preconditioner_fn should Act like + left application of an n by n linear operator, i.e. preconditioner_fn(n x + m) should have shape n x m. Somewhat amazingly, this algorithm doesn't + care if the passed preconditioner is a left, right, or split + preconditioner -- the same output will result. Just be sure to pass in + full_preconditioner.solve when using a SplitPreconditioner. + max_iters: Run conjugate gradients for at most this many iterations. Note + that no matter the value of max_iters, the loop will run for at most n + iterations. + tolerance: Stop early if all the errors have less than this magnitude. + + Returns: + A pair (C, T) where C is the n x t matrix A^(-1) B and T is a + SymmetricTridiagonalMatrix where the diag part is of shape (t, max_iters) + and the off_diag part is of shape (t, max_iters -1). + """ + n, t = B.shape + init_solutions = jnp.zeros_like(B) + # Algorithm 2 has + # current_errors = matrix_matrix_multiplier(current_solutions) - B + # but this leads to the first update being towards -B instead of B, + # which is wrong. + init_errors = B + init_search_directions = jnp.zeros_like(B) + # init_preconditioned_errors doesn't get used because + # init_search_directions is zero, but it needs to be non-zero to avoid + # divide by zero nans. + init_preconditioned_errors = B + max_iters = min(max_iters, n) + + diags = jnp.ones(shape=(t, max_iters), dtype=B.dtype) + # When n = 1, we still need a location to write a dummy value. + off_diags = jnp.zeros(shape=(t, max(1, max_iters - 1)), dtype=B.dtype) + + def loop_body(carry, _): + """Body for jax.lax.while_loop.""" + (j, old_errors, old_solutions, + old_preconditioned_errors, old_search_directions, + old_alpha, beta_factor, diags, off_diags) = carry + preconditioned_errors = preconditioner_fn(old_errors) + + converged = jnp.all(jnp.abs(old_errors) < tolerance, axis=0) + # We check convergence per batch member. 
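+    # converged[i] is True once every entry of the i-th residual column is
+    # below `tolerance`; converged columns are frozen via jnp.where below so
+    # the scan can always run a fixed max_iters steps under jit.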
+ assert converged.shape == (t,) + + # beta is a size t vector, i-th entry = + # preconditioned_errors[:, i] dot preconditioned_errors[:, i] + # ----------------------------------------------------------- + # old_preconditioned_errors[:, i] dot old_preconditioned_errors[:, i] + beta_numerator = jnp.einsum('ij,ij->j', + preconditioned_errors, preconditioned_errors) + beta_denominator = jnp.einsum( + 'ij,ij->j', old_preconditioned_errors, old_preconditioned_errors) + beta = beta_factor * beta_numerator / beta_denominator + safe_beta = jnp.where(converged, 1., beta) + + search_directions = jnp.where( + converged, old_search_directions, preconditioned_errors + + safe_beta[jnp.newaxis] * old_search_directions) + + v = matrix_matrix_multiplier(search_directions) + + # alpha is a size t vector, i-th entry = + # current_errors[:, i] dot preconditioned_errors[:, i] + # ---------------------------------------------------- + # search_directions[:, i] dot v[:, i] + alpha_num = jnp.einsum('ij,ij->j', old_errors, preconditioned_errors) + alpha_denom = jnp.einsum('ij,ij->j', search_directions, v) + alpha = alpha_num / alpha_denom + safe_alpha = jnp.where(converged, 1., alpha) + + new_solutions = jnp.where( + converged, + old_solutions, + old_solutions + safe_alpha[jnp.newaxis] * search_directions) + # TODO(srvasude): Test out the following change: + # new_errors = B - matrix_matrix_multiplier(new_solutions) + # While requiring one more matrix multiplication, this is a more numerically + # stable expression and can be used more reliably as a stopping criterion. + new_errors = jnp.where( + converged, + old_errors, + old_errors - safe_alpha[jnp.newaxis] * v) + + # When j = 0, beta = 0 (because of beta_factor), so old_alpha doesn't + # matter. + # Here and below, use the double-where trick to avoid NaN gradients. + diag_update = jnp.where( + converged, 1., 1. / safe_alpha + safe_beta / old_alpha) + new_diags = diags.at[:, j].set(diag_update) + # When j = 0, beta = 0 (because of beta_factor), so this ends up writing + # a zero vector at the end of off_diags, which is a no-op. + off_diag_update = jnp.where(converged, 0., jnp.sqrt(safe_beta) / old_alpha) + new_off_diags = off_diags.at[:, j - 1].set(off_diag_update) + + # Only update if we are not within tolerance. + (preconditioned_errors, search_directions, alpha) = (jax.tree_map( + lambda o, n: jnp.where(converged, o, n), + (old_preconditioned_errors, old_search_directions, old_alpha), + (preconditioned_errors, search_directions, safe_alpha))) + + new_beta_factor = jnp.array([1.0], dtype=B.dtype) + + return (j + 1, new_errors, new_solutions, + preconditioned_errors, search_directions, alpha, new_beta_factor, + new_diags, new_off_diags), () + + init_alpha = jnp.ones(shape=(t,), dtype=B.dtype) + beta_factor = jnp.array([0.0], dtype=B.dtype) + + scan_out, _ = jax.lax.scan( + loop_body, + (0, init_errors, init_solutions, + init_preconditioned_errors, init_search_directions, + init_alpha, beta_factor, + diags, off_diags), + None, max_iters) + _, _, solutions, _, _, _, _, diags, off_diags = scan_out + + return solutions, SymmetricTridiagonalMatrix(diag=diags, off_diag=off_diags) + + +@jax.jit +def tridiagonal_det(diag: Array, off_diag: Array) -> float: + """Return the determinant of a tridiagonal matrix.""" + # From https://en.wikipedia.org/wiki/Tridiagonal_matrix#Determinant + # TODO(thomaswc): Turn this into a method of SymmetricTridiagonalMatrix. + # Using a scan reduce the number of ops in the graph, therefore reducing + # lowering and compile times. 
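+  # The scan below implements the standard three-term (continuant) recurrence
+  #   f_k = diag[k] * f_{k-1} - off_diag[k-1]**2 * f_{k-2},
+  # with f_{-1} = 1 and f_0 = diag[0]; the determinant is the final f value.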
+ def scan_body(carry, xs): + d = xs[0] + o = xs[1] + pv = carry[0] + v = carry[1] + new_value = d * v - o**2 * pv + return (v, new_value), new_value + + initial_d = diag[0] + # Return the last value of the determinant recursion. + return jax.lax.scan( + scan_body, + init=(jnp.ones_like(initial_d), initial_d), + xs=(diag[1:], off_diag))[1][-1] diff --git a/tensorflow_probability/python/experimental/fastgp/mbcg_test.py b/tensorflow_probability/python/experimental/fastgp/mbcg_test.py new file mode 100644 index 0000000000..4370dae084 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/mbcg_test.py @@ -0,0 +1,202 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for mbcg.py.""" + +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +import scipy +from tensorflow_probability.python.experimental.fastgp import mbcg +from absl.testing import absltest + +# pylint: disable=invalid-name + + +class _MbcgTest(absltest.TestCase): + + def test_modified_batched_conjugate_gradients(self): + A = jnp.array([[1.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + def multiplier(B): + return A @ B + v = jnp.array([4.0, 6.0], dtype=self.dtype) + w = v[:, jnp.newaxis] + z = multiplier(w) + def identity(x): + return x + inverse, t = mbcg.modified_batched_conjugate_gradients( + multiplier, z, identity) + np.testing.assert_allclose(v, inverse[:, 0], rtol=1e-6) + # The tridiagonal matrix should have approximately the same determinant + # as A. + np.testing.assert_allclose( + mbcg.tridiagonal_det(t.diag[0], t.off_diag[0]), + 3.0, rtol=1e-6) + + # Now, with jit. + def mbcg_pure_tensor(M, B): + return mbcg.modified_batched_conjugate_gradients( + lambda x: M @ x, B, identity) + mbcg_jit = jax.jit(mbcg_pure_tensor) + inverse2, t2 = mbcg_jit(A, z) + np.testing.assert_allclose(v, inverse2[:, 0], rtol=1e-6) + # The tridiagonal matrix should have approximately the same determinant + # as A. 
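+    # (In exact arithmetic a full run of mBCG yields A = Q^t T Q with Q
+    # orthonormal, so the Lanczos tridiagonal T has the same determinant
+    # as A; here det(A) = 1 * 4 - 1 * 1 = 3.)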
+ np.testing.assert_allclose( + mbcg.tridiagonal_det(t2.diag[0], t2.off_diag[0]), + 3.0, rtol=1e-6) + + def test_mbcg_scalar(self): + A = jnp.array([[2.0]], dtype=self.dtype) + def multiplier(B): + return A @ B + v = jnp.array([5.0], dtype=self.dtype) + w = v[:, jnp.newaxis] + z = multiplier(w) + def identity(x): + return x + inverse, t = mbcg.modified_batched_conjugate_gradients( + multiplier, z, identity) + np.testing.assert_allclose(v, inverse[:, 0], rtol=1e-6) + np.testing.assert_allclose(t.diag[0][0], 2.0, rtol=1e-6) + + def test_mbcg_three_by_three(self): + A = jnp.array( + [[1.0, 0.0, 1.0], [0.0, 4.0, 0.0], [1.0, 0.0, 6.0]], + dtype=self.dtype) + w = jnp.array([7.0, 8.0, 9.0], dtype=self.dtype)[:, jnp.newaxis] + inverse, t = mbcg.modified_batched_conjugate_gradients( + lambda B: A @ B, A @ w, lambda x: x) + np.testing.assert_allclose(w[:, 0], inverse[:, 0], rtol=1e-6) + np.testing.assert_allclose( + mbcg.tridiagonal_det(t.diag[0], t.off_diag[0]), + 20.0, rtol=1e-6) + + def test_mbcg_identity(self): + W = np.random.rand(10, 5).astype(self.dtype) + inverse, tridiagonals = mbcg.modified_batched_conjugate_gradients( + lambda B: B, W, lambda x: x) + np.testing.assert_allclose(inverse, W) + np.testing.assert_allclose( + tridiagonals.off_diag[0], + jnp.array( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + dtype=self.dtype)) + np.testing.assert_allclose( + tridiagonals.diag[0], + jnp.array( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + dtype=self.dtype)) + + def test_mbcg_diagonal(self): + A = jnp.diag(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=self.dtype)) + w = jnp.array([1.0, 0.5, 1.5, 2.0, -1.0], dtype=self.dtype)[:, jnp.newaxis] + _, tridiagonals = mbcg.modified_batched_conjugate_gradients( + lambda B: A @ B, A @ w, lambda x: x + ) + evalues, _ = scipy.linalg.eigh_tridiagonal( + tridiagonals.diag[0], tridiagonals.off_diag[0] + ) + np.testing.assert_allclose(jnp.diag(A), sorted(evalues), rtol=1e-6) + + def test_mbcg_diagonal2(self): + A = jnp.diag( + jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=self.dtype)) + w = jnp.array([-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0], + dtype=self.dtype)[:, jnp.newaxis] + _, tridiagonals = mbcg.modified_batched_conjugate_gradients( + lambda B: A @ B, A @ w, lambda x: x + ) + evalues, _ = scipy.linalg.eigh_tridiagonal( + tridiagonals.diag[0], tridiagonals.off_diag[0] + ) + np.testing.assert_allclose(jnp.diag(A), sorted(evalues), rtol=1e-6) + + def test_mbcg_batching(self): + A = jnp.array([[2.0, 1.0], [1.0, 3.0]], dtype=self.dtype) + W = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], dtype=self.dtype) + P = jnp.array([[0.5, 0.0], [0.0, 1.0/3.0]], dtype=self.dtype) + preconditioner_fn = lambda B: P @ B + inverses, tridiagonals = mbcg.modified_batched_conjugate_gradients( + lambda B: A @ B, A @ W, preconditioner_fn) + np.testing.assert_allclose(inverses, W, atol=0.1) + np.testing.assert_allclose( + jnp.array([[1.29, 0.74], [1.35, 0.66], [1.37, 0.63]], dtype=self.dtype), + tridiagonals.diag, rtol=0.1) + np.testing.assert_allclose( + jnp.array([[0.34], [0.23], [0.18]], dtype=self.dtype), + tridiagonals.off_diag, rtol=0.1) + + def test_mbcg_max_iters(self): + A = jax.random.normal( + jax.random.PRNGKey(2), (30, 30)).astype(self.dtype) + A = A @ A.T + # Ensure it is diagonally dominant. + i, j = jnp.diag_indices(30) + A = A.at[..., i, j].set(0.) + A = A.at[..., i, j].set(10. * jnp.sum(jnp.abs(A), axis=-1)) + # Finally divide by the diagonal to ensure small condition number. 
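+    # (Conjugate-gradient-style iterations converge at a rate governed by the
+    # condition number, so a well-conditioned A keeps max_iters=20 adequate
+    # for the 30 x 30 determinant check below.)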
+ A = A / jnp.diag(A)[..., jnp.newaxis] + + _, t = mbcg.modified_batched_conjugate_gradients( + lambda t: jnp.matmul(A, t), + B=jnp.ones([A.shape[-1], 1], dtype=A.dtype), + max_iters=20, + preconditioner_fn=lambda a: a) + # Ensure that the shape is at most max iterations. + self.assertEqual(t.diag[0].shape[-1], 20) + np.testing.assert_allclose( + mbcg.tridiagonal_det(t.diag[0], t.off_diag[0]), + jnp.exp(jnp.linalg.slogdet(A)[1]), rtol=6.3e-2) + + def test_value_and_grad(self): + def norm1_of_inverse(mat): + return jnp.sum(mbcg.modified_batched_conjugate_gradients( + lambda t: jnp.matmul(mat, t), + B=jnp.ones([mat.shape[-1], 1]), + preconditioner_fn=lambda a: a)[0]) + + v, g = jax.value_and_grad(norm1_of_inverse)(jnp.eye(3, dtype=self.dtype)) + np.testing.assert_allclose(v, [3.0]) + np.testing.assert_allclose(g, jnp.full((3, 3), -1.0)) + + def test_tridiagonal_det(self): + diag = jnp.array([-2, -2, -2, -2], dtype=self.dtype) + off_diag = jnp.array([1, 1, 1], dtype=self.dtype) + np.testing.assert_allclose(mbcg.tridiagonal_det(diag, off_diag), 5) + + # Verified by on-line determinant calculator + np.testing.assert_allclose( + 26.3374, + mbcg.tridiagonal_det(jnp.array([5.724855, 4.1276183, 1.1475269]), + jnp.array([0.6466101, 0.22849184])), + rtol=1e+2 + ) + + +class MbcgTestFloat32(_MbcgTest): + dtype = np.float32 + + +class MbcgTestFloat64(_MbcgTest): + dtype = np.float64 + + +del _MbcgTest + + +if __name__ == '__main__': + config.update('jax_enable_x64', True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py b/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py new file mode 100644 index 0000000000..aa61e7385c --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py @@ -0,0 +1,281 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Run Lanczos for m iterations to get a good preconditioner.""" + +from typing import Callable, Tuple + +import jax +import jax.numpy as jnp +import scipy +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.substrates import jax as tfp + +jtf = tfp.tf2jax +Array = jnp.ndarray + +# pylint: disable=invalid-name + + +def gram_schmidt(Q, v): + """Return a vector after applying the Gram-Schmidt process to it. + + Args: + Q: A tensor of shape (batch_dims, m, n) representing a batch of + orthonormal m-by-n matrices. The `m` vectors to perform Gram-Schmidt + with respect to. + v: Initial starting vector of shape (n, batch_dims). + + Returns: + A tensor t of shape (n, batch_dims) such that for every i, + t[:, i] is orthogonal to Q[i, j, :] for every j. + """ + # This will be used several times to do twice is enough reorthogonalization, + # after initial orthogonalization. + # See: + # [1] L. Giruad, J. Langou, M. 
Rozloznik, On the round-offf error analysis of + # the Gram-Schmidt Algorithm with reorthogonalization. + correction = Q @ v.T[..., jnp.newaxis] + correction = jnp.squeeze(jnp.swapaxes(Q, -1, -2) @ correction, axis=-1) + return v - correction.T + + +def reorthogonalize(Q, v): + for _ in range(3): + v = gram_schmidt(Q, v) + return v + + +@jax.named_call +def partial_lanczos( + multiplier: Callable[[Array], Array], + v: Array, + key: jax.Array, + num_iters: int = 20) -> Tuple[Array, mbcg.SymmetricTridiagonalMatrix]: + """Returns orthonormal Q and tridiagonal T such that A ~ Q^t T Q. + + Similar to modified_batched_conjugate_gradients, but this returns the + Q matrix that performs the tridiagonalization. (But unlike mbcg, this + doesn't support preconditioning.) + + Args: + multiplier: A function for left matrix multiplying by an n by n + symmetric positive definite matrix A. + v: A tensor of shape (n, k) representing a batch of k n-dimensional + vectors. Each vector v[:, i] will be used to form a Krylov subspace + spanned by A^j v[:, i] for j in [0, num_iters). + key: RNG generator used to initialize the model's parameters. + num_iters: The number of iterations to run the partial Lanczos algorithm + for. + + Returns: + A pair (Q, T) where Q is a shape (k, num_iters, n) batch of k + num_iters-by-n orthonormal matrices and T is a size k batch of + num_iters-by-num_iters symmetric tridiagonal matrices. + """ + n = v.shape[0] + k = v.shape[1] + num_iters = int(min(n, num_iters)) + + def scan_func(loop_info, unused_x): + i, old_Q, old_diags, old_off_diags, old_v, key = loop_info + key1, key2 = jax.random.split(key) + + beta = jnp.linalg.norm(old_v, axis=0, keepdims=True) + norm = beta + + # If beta is too small, that means that we have / are close to finding a + # full basis for the Krylov subspace. We need to rejuvenate the subspace + # with a new random vector that is orthogonal to what has come so far. + old_v_too_small = jax.random.uniform( + key1, shape=old_v.shape, minval=-1.0, maxval=1.0, dtype=old_v.dtype) + old_v_too_small = reorthogonalize(old_Q, old_v_too_small) + + norm_too_small = jnp.linalg.norm(old_v_too_small, axis=0, keepdims=True) + + eps = 10 * jnp.finfo(old_Q.dtype).eps + + old_v = jnp.where(beta < eps, old_v_too_small, old_v) + norm = jnp.where(beta < eps, norm_too_small, norm) + beta = jnp.where(beta < eps, 0., beta) + + w = old_v / norm + + Aw = multiplier(w) + alpha = jnp.einsum('ij,ij->j', w, Aw) + + # Compute full reorthogonalization every time, using "twice is enough" + # reorthogonalization. + v = Aw - alpha[jnp.newaxis, :] * w + v = reorthogonalize(old_Q, v) + + diags = old_diags.at[:, i].set(alpha) + Q = old_Q.at[:, i, :].set(jnp.transpose(w)) + # Using an if here is okay, because num_iters will be statically known. + # (Either because it will be set explicitly, or because n will be + # statically known. Note that we can't use a jax.lax.cond here because + # the at/set branch will raise an error when num_iters == 1. + if num_iters > 1: + off_diags = old_off_diags.at[:, i - 1].set(beta[0]) + else: + off_diags = old_off_diags + + new_loop_info = (i + 1, Q, diags, off_diags, v, key2) + + return new_loop_info, None + + init_Q = jnp.zeros(shape=(k, num_iters, n), dtype=v.dtype) + init_diags = jnp.ones(shape=(k, num_iters), dtype=v.dtype) + init_off_diags = jnp.zeros(shape=(k, num_iters - 1), dtype=v.dtype) + # Normalize v beforehand. This is because v could have a small norm, and + # trigger rejuvenating the Krylov subspace, which we don't want. 
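+  # The scan carry is (iteration index, Q, diags, off_diags, current Lanczos
+  # vector, PRNG key); each column of v is normalized independently below.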
+ v = v / jnp.linalg.norm(v, axis=0, keepdims=True) + initial_state = (0, init_Q, init_diags, init_off_diags, v, key) + + scan_out, _ = jax.lax.scan( + scan_func, + initial_state, + None, + num_iters) + + _, Q, diags, off_diags, _, _ = scan_out + + return Q, mbcg.SymmetricTridiagonalMatrix(diag=diags, off_diag=off_diags) + + +def make_lanczos_preconditioner( + kernel: jtf.linalg.LinearOperator, + key: jax.Array, + num_iters: int = 20): + """Return a preconditioner as a linear operator.""" + n = kernel.shape[-1] + key1, key2 = jax.random.split(key) + v = jax.random.uniform( + key1, shape=(n, 1), minval=-1.0, maxval=1.0, dtype=kernel.dtype) + Q, T = partial_lanczos(lambda x: kernel @ x, v, key2, num_iters) + + # Now diagonalize T as Q^t D Q + # TODO(thomaswc): Replace this scipy call with jnp.linalg.eigh so that + # it can be jit-ed. + evalues, evectors = scipy.linalg.eigh_tridiagonal( + T.diag[0, :], T.off_diag[0, :]) + sqrt_evalues = jnp.sqrt(evalues) + + F = jnp.einsum('i,ij->ij', sqrt_evalues, evectors @ Q[0]) + + # diag(F^t F)_i = sum_k (F^t)_{i, k} F_{k, i} = sum_k F_{k, i}^2 + diag_Ft_F = jnp.sum(F * F, axis=0) + residual_diag = jtf.linalg.diag_part(kernel) - diag_Ft_F + + eps = jnp.finfo(kernel.dtype).eps + + # TODO(srvasude): Modify this when residual_diag is near zero. This means that + # we captured the diagonal appropriately, and modifying with a shift of eps + # can alter the preconditioner greatly. + diag_linop = tfp.tf2jax.linalg.LinearOperatorDiag( + jnp.maximum(residual_diag, 0.0) + 10. * eps, is_positive_definite=True + ) + return tfp.tf2jax.linalg.LinearOperatorLowRankUpdate( + diag_linop, jnp.transpose(F), is_positive_definite=True + ) + + +def my_tridiagonal_solve( + lower_diagonal: Array, + middle_diagonal: Array, + upper_diagonal: Array, + b: Array) -> Array: + """Like jax.linalg.tridiagonal_solve, but works for all sizes.""" + m, = middle_diagonal.shape + + if m >= 3: + return jax.lax.linalg.tridiagonal_solve( + lower_diagonal, middle_diagonal, upper_diagonal, b) + + if m == 1: + return b / middle_diagonal[0] + + if m == 2: + return jnp.linalg.solve( + jnp.array([[middle_diagonal[0], upper_diagonal[0]], + [lower_diagonal[1], middle_diagonal[1]]]), + b) + + if m == 0: + return b + + raise ValueError(f'Logic error; unanticipated {m=}') + + +def tridiagonal_solve_multishift( + T: mbcg.SymmetricTridiagonalMatrix, + shifts: Array, + v_norm: Array) -> Array: + """Solve (T - shift_k I) x_k = v_norm e_1 for batch of sym. tridiagonal T. + + Args: + T: A batched SymmetricTridiagonalMatrix. T.diag should be of shape + (k, n). + shifts: A tensor of shape (s) representing s distinct scalar shifts. + v_norm: A size (k) batch of vector lengths. + + Returns: + A tensor x of shape (s, k, n) that approximately satisfies + (T[k] - shifts[i] I) x[i, j, :] = v_norm[k] e_1 + """ + n = T.diag.shape[-1] + lower_diagonal = jnp.pad(T.off_diag, ((0, 0), (1, 0))) + upper_diagonal = jnp.pad(T.off_diag, ((0, 0), (0, 1))) + target = v_norm[..., jnp.newaxis, jnp.newaxis] * jnp.eye( + n, 1, dtype=T.diag.dtype) + # Batch over the Tridiagonal matrix batch. + batch_solve = jax.vmap(my_tridiagonal_solve, in_axes=(0, 0, 0, 0)) + # Batch over the shift dimension. 
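+  # Only the middle diagonal depends on the shift (T.diag - shift), so the
+  # remaining operands are broadcast with in_axes=None.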
+ multishift_batch_solve = jax.vmap(batch_solve, in_axes=(None, 0, None, None)) + solutions = multishift_batch_solve( + lower_diagonal, + T.diag - shifts[..., jnp.newaxis, jnp.newaxis], + upper_diagonal, + target) + return jnp.squeeze(solutions, axis=-1) + + +@jax.named_call +def psd_solve_multishift( + multiplier: Callable[[Array], Array], + v: Array, + shifts: Array, + key: jax.Array, + num_iters: int = 20) -> Array: + """Solve (A - shift_k I) x_k = v for PSD A. + + Args: + multiplier: A function for left matrix multiplying by an n by n + symmetric positive definite matrix A. + v: A tensor of shape (n, t) representing t n-dim vectors. + shifts: A tensor of shape (s) representing s distinct scalar shifts. + key: A random seed. + num_iters: The number of iterations to run the Lanczos tridiagonalization + algorithm for. + + Returns: + A tensor x of shape (s, t, n) that approximately satisfies + (A - shift[i] I) x[i, j, :] = v[:, j] + """ + Q, T = partial_lanczos(multiplier, v, key, num_iters) + v_norm = jnp.linalg.norm(v, axis=0, keepdims=False) + ys = tridiagonal_solve_multishift(T, shifts, v_norm) + # Q is of shape (t, num_iters, n) and ys is of shape + # (num_shifts, t, num_iters). + return jnp.einsum('tin,sti->stn', Q, ys) diff --git a/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py b/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py new file mode 100644 index 0000000000..31f58f5d37 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py @@ -0,0 +1,202 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for partial_lanczos.py.""" + +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.python.experimental.fastgp import partial_lanczos +from absl.testing import absltest + +# pylint: disable=invalid-name + + +class _PartialLanczosTest(absltest.TestCase): + def test_gram_schmidt(self): + w = jnp.ones((5, 1), dtype=self.dtype) + v = partial_lanczos.gram_schmidt( + jnp.array([[[1.0, 0, 0, 0, 0], + [0, 1.0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]], dtype=self.dtype), + w) + self.assertEqual((5, 1), v.shape) + self.assertEqual(0.0, v[0][0]) + self.assertEqual(0.0, v[1][0]) + self.assertGreater(jnp.linalg.norm(v[:, 0]), 1e-6) + + def test_partial_lanczos_identity(self): + A = jnp.identity(10).astype(self.dtype) + v = jnp.ones((10, 1)).astype(self.dtype) + Q, T = partial_lanczos.partial_lanczos( + lambda x: A @ x, v, jax.random.PRNGKey(2), 10) + np.testing.assert_allclose(jnp.identity(10), Q[0] @ Q[0].T, atol=1e-6) + np.testing.assert_allclose( + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), + 1.0, rtol=1e-5) + + def test_diagonal_matrix_heavily_imbalanced(self): + A = jnp.diag(jnp.array([ + 1e-3, 1., 2., 3., 4., 10000.], dtype=self.dtype)) + v = jnp.ones((6, 1)).astype(self.dtype) + Q, T = partial_lanczos.partial_lanczos( + lambda x: A @ x, v, jax.random.PRNGKey(9), 6) + atol = 1e-6 + det_rtol = 1e-6 + if self.dtype == np.float32: + atol = 2e-3 + det_rtol = 0.26 + np.testing.assert_allclose(jnp.identity(6), Q[0] @ Q[0].T, atol=atol) + np.testing.assert_allclose( + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), + 240., rtol=det_rtol) + + def test_partial_lanczos_full_lanczos(self): + A = jnp.array([[1.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + v = jnp.array([[-1.0], [1.0]], dtype=self.dtype) + Q, T = partial_lanczos.partial_lanczos( + lambda x: A @ x, v, jax.random.PRNGKey(3), 2) + np.testing.assert_allclose(jnp.identity(2), Q[0] @ Q[0].T, atol=1e-6) + np.testing.assert_allclose( + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), + 3.0, rtol=1e-5) + + def test_partial_lanczos_with_jit(self): + def partial_lanczos_pure_tensor(A, v): + return partial_lanczos.partial_lanczos( + lambda x: A @ x, v, jax.random.PRNGKey(4), 2) + + partial_lanczos_jit = jax.jit(partial_lanczos_pure_tensor) + + A = jnp.array([[1.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + v = jnp.array([[-1.0], [1.0]], dtype=self.dtype) + Q, T = partial_lanczos_jit(A, v) + np.testing.assert_allclose(jnp.identity(2), Q[0] @ Q[0].T, atol=1e-6) + np.testing.assert_allclose( + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), + 3.0, rtol=1e-5) + + def test_partial_lanczos_with_batching(self): + v = jnp.zeros((10, 3), dtype=self.dtype) + v = v.at[:, 0].set(jnp.ones(10)) + v = v.at[:, 1].set(jnp.array([1, -1, 1, -1, 1, -1, 1, -1, 1, -1])) + v = v.at[:, 2].set(jnp.array( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])) + Q, T = partial_lanczos.partial_lanczos( + lambda x: x, v, jax.random.PRNGKey(5), 10) + self.assertEqual(Q.shape, (3, 10, 10)) + self.assertEqual(T.diag.shape, (3, 10)) + np.testing.assert_allclose( + Q[0, 0, :], jnp.ones(10) / jnp.sqrt(10.0)) + np.testing.assert_allclose( + Q[1, 0, :], v[:, 1] / jnp.sqrt(10.0)) + + def test_make_lanczos_preconditioner(self): + kernel = jnp.identity(10).astype(self.dtype) + preconditioner = partial_lanczos.make_lanczos_preconditioner( 
+ kernel, jax.random.PRNGKey(5)) + log_det = preconditioner.log_abs_determinant() + self.assertAlmostEqual(0.0, log_det, places=4) + out = preconditioner.solve(jnp.identity(10)) + np.testing.assert_allclose(out, jnp.identity(10), atol=9e-2) + kernel = jnp.identity(100).astype(self.dtype) + preconditioner = partial_lanczos.make_lanczos_preconditioner( + kernel, jax.random.PRNGKey(6)) + # TODO(thomaswc): Investigate ways to improve the numerical stability + # here so that log_det is closer to zero than this. + log_det = preconditioner.log_abs_determinant() + self.assertLess(jnp.abs(log_det), 10.0) + out = preconditioner.solve(jnp.identity(100)) + np.testing.assert_allclose(out, jnp.identity(100), atol=0.2) + + def test_preconditioner_preserves_psd(self): + M = jnp.array([[2.6452732, -1.4553788, -0.5272188, 0.524349], + [-1.4553788, 4.4274387, 0.21998158, 1.8666775], + [-0.5272188, 0.21998158, 2.4756536, -0.5257966], + [0.524349, 1.8666775, -0.5257966, 2.889879]]).astype( + self.dtype) + orig_eigenvalues = jnp.linalg.eigvalsh(M) + self.assertFalse((orig_eigenvalues < 0).any()) + + preconditioner = partial_lanczos.make_lanczos_preconditioner( + M, jax.random.PRNGKey(7)) + preconditioned_M = preconditioner.solve(M) + after_eigenvalues = jnp.linalg.eigvalsh(preconditioned_M) + self.assertFalse((after_eigenvalues < 0).any()) + + def test_my_tridiagonal_solve(self): + empty = jnp.array([]).astype(self.dtype) + self.assertEqual( + 0, + partial_lanczos.my_tridiagonal_solve( + empty, empty, empty, empty).size) + + np.testing.assert_allclose( + jnp.array([2.5]), + partial_lanczos.my_tridiagonal_solve( + jnp.array([0.0], dtype=self.dtype), + jnp.array([2.0], dtype=self.dtype), + jnp.array([0.0], dtype=self.dtype), + jnp.array([5.0], dtype=self.dtype))) + + np.testing.assert_allclose( + jnp.array([-4.5, 3.5]), + partial_lanczos.my_tridiagonal_solve( + jnp.array([0.0, 1.0], dtype=self.dtype), + jnp.array([2.0, 3.0], dtype=self.dtype), + jnp.array([4.0, 0.0], dtype=self.dtype), + jnp.array([5.0, 6.0], dtype=self.dtype))) + + np.testing.assert_allclose( + jnp.array([-33.0/2.0, 115.0/12.0, -11.0/6.0])[:, jnp.newaxis], + partial_lanczos.my_tridiagonal_solve( + jnp.array([0.0, 1.0, 2.0], dtype=self.dtype), + jnp.array([3.0, 4.0, 5.0], dtype=self.dtype), + jnp.array([6.0, 7.0, 0.0], dtype=self.dtype), + jnp.array([8.0, 9.0, 10.0], dtype=self.dtype)[:, jnp.newaxis]), + atol=1e-6 + ) + + def test_psd_solve_multishift(self): + v = jnp.array([1.0, 1.0, 1.0, 1.0], dtype=self.dtype) + solutions = partial_lanczos.psd_solve_multishift( + lambda x: x, + v[:, jnp.newaxis], + jnp.array([0.0, 2.0, -1.0], dtype=self.dtype), + jax.random.PRNGKey(8)) + np.testing.assert_allclose( + solutions[:, 0, :], + [[1.0, 1.0, 1.0, 1.0], + [-1.0, -1.0, -1.0, -1.0], + [0.5, 0.5, 0.5, 0.5]]) + + +class PartialLanczosTestFloat32(_PartialLanczosTest): + dtype = np.float32 + + +class PartialLanczosTestFloat64(_PartialLanczosTest): + dtype = np.float64 + + +del _PartialLanczosTest + + +if __name__ == "__main__": + config.update("jax_enable_x64", True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners.py b/tensorflow_probability/python/experimental/fastgp/preconditioners.py new file mode 100644 index 0000000000..d891cd9531 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/preconditioners.py @@ -0,0 +1,734 @@ +# Copyright 2024 The TensorFlow Probability Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Library of popular matrix preconditioners. + +Given a matrix M, a preconditioner P of M is an easy to compute approximation +of M with the following properties: + + 1) Pv and P^(-1)v are easy to compute, preferably in time closer to O(n) + than O(n^2), + 2) det P is easy to compute, and + 3) one of M P^(-1), P^(-1) M, or A^(-1) M B^(-1) (with P = A B) is closer + to the identity than M is. The specific distance metric we are most often + interested in is condition number, and the three options are referred to + as "right", "left" or "split" preconditioning respectively. + +For more details, see chapter 9 of "Iterative Methods for Sparse Linear +Systems", available online at +https://www-users.cse.umn.edu/~saad/IterMethBook_2ndEd.pdf + +The preconditioners here are intended to be used to improve the convergence +of the log det and yt_inv_y operations needed by Gaussian Processes. The +preconditioner class you will want to use will depend on the GP kernel type. +https://arxiv.org/pdf/2107.00243.pdf suggests that partial Cholesky or QFF +work well for RBF kernels, and that RFF or truncated SVD might be good when +nothing is known about the kernel structure. 
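+
+A minimal usage sketch (the 2 x 2 matrix is purely illustrative; any PSD
+matrix or `LinearOperator` works the same way):
+
+  import jax.numpy as jnp
+  from tensorflow_probability.python.experimental.fastgp import preconditioners
+
+  M = jnp.array([[2.0, 1.0], [1.0, 3.0]])
+  pre = preconditioners.DiagonalPreconditioner(M)
+  P = pre.full_preconditioner()            # a tf2jax LinearOperator
+  log_det_P = pre.log_det()                # log |det P|
+  M_P_inv = pre.preconditioned_operator()  # right preconditioning, M P^(-1)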
+""" + +import jax +import jax.experimental.sparse +import jax.numpy as jnp +from jaxtyping import Float +from tensorflow_probability.python.experimental.fastgp import linalg +from tensorflow_probability.python.experimental.fastgp import linear_operator_sum +from tensorflow_probability.substrates import jax as tfp + +jtf = tfp.tf2jax + +# pylint: disable=invalid-name + + +@jax.named_call +def promote_to_operator(M) -> jtf.linalg.LinearOperator: + if isinstance(M, jtf.linalg.LinearOperator): + return M + return jtf.linalg.LinearOperatorFullMatrix(M, is_non_singular=True) + + +def _diag_part(M) -> jax.Array: + if isinstance(M, jtf.linalg.LinearOperator): + return M.diag_part() + return jtf.linalg.diag_part(M) + + +class Preconditioner: + """Base class for preconditioners.""" + + def __init__(self, M: jtf.linalg.LinearOperator): + self.M = M + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + """Returns the preconditioner.""" + raise NotImplementedError('Base classes must override full_preconditioner.') + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + """Returns the combined action of M and the preconditioner.""" + raise NotImplementedError( + 'Base classes must override preconditioned_operator.') + + def log_det(self) -> jtf.linalg.LinearOperator: + """The log absolute value of the determinant of the preconditioner.""" + return self.full_preconditioner().log_abs_determinant() + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + """Returns tr( P^(-1) A ) for a n x n, non-batched A.""" + result = self.full_preconditioner().solve(A) + if isinstance(result, jtf.linalg.LinearOperator): + return result.trace() + return jnp.trace(result) + + +@jax.tree_util.register_pytree_node_class +class IdentityPreconditioner(Preconditioner): + """The do-nothing preconditioner.""" + + def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs): + n = M.shape[-1] + self.id = jtf.linalg.LinearOperatorIdentity(n, dtype=M.dtype) + super().__init__(M) + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + return self.id + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + return promote_to_operator(self.M) + + def log_det(self) -> Float: + return 0.0 + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + return jnp.trace(A) + + def tree_flatten(self): + return ((self.M,), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls(*children) + + +@jax.tree_util.register_pytree_node_class +class DiagonalPreconditioner(Preconditioner): + """The best diagonal preconditioner; aka the Jacobi preconditioner.""" + + def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs): + self.d = jnp.maximum(_diag_part(M), 1e-6) + super().__init__(M) + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + return jtf.linalg.LinearOperatorDiag( + self.d, is_non_singular=True, is_positive_definite=True + ) + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + return jtf.linalg.LinearOperatorComposition( + [promote_to_operator(self.M), self.full_preconditioner().inverse()] + ) + + def log_det(self) -> Float: + return jnp.sum(jnp.log(self.d)) + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + return jnp.sum(jnp.diag(A) / self.d) + + def tree_flatten(self): + return ((self.M,), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls(*children) + + +@jax.tree_util.register_pytree_node_class +class LowRankPreconditioner(Preconditioner): + 
"""Turns M ~ A A^t for low rank A into a preconditioner.""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + low_rank: jax.Array, + residual_diag: jax.Array = None, + ): + self.low_rank = low_rank + n, self.r = self.low_rank.shape + assert n == M.shape[-1], ( + f'Low Rank has shape {self.low_rank.shape}; should have shape' + f' ({M.shape[-1]}, r)' + ) + + if residual_diag is None: + self.residual_diag = _diag_part(M) - jnp.einsum( + 'ij,ij->i', self.low_rank, self.low_rank + ) + else: + self.residual_diag = residual_diag + + self.residual_diag = jnp.maximum(1e-6, self.residual_diag) + + diag_op = jtf.linalg.LinearOperatorDiag( + self.residual_diag, is_non_singular=True, is_positive_definite=True) + self.pre = jtf.linalg.LinearOperatorLowRankUpdate( + diag_op, self.low_rank, is_positive_definite=True) + super().__init__(M) + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + return self.pre + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + return jtf.linalg.LinearOperatorComposition( + [promote_to_operator(self.M), self.pre.inverse()] + ) + + @classmethod + def from_lowrank(cls, M, low_rank): + """Alternate constructor when low_rank is already made.""" + x = LowRankPreconditioner(M, low_rank) + x.__class__ = cls + return x + + def tree_flatten(self): + return ((self.M, self.low_rank), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls.from_lowrank(*children) + + +@jax.tree_util.register_pytree_node_class +class RankOnePreconditioner(LowRankPreconditioner): + """Preconditioner based on M ~ v v^t using M's largest eigenvector v.""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + num_iters: int = 10, + **unused_kwargs, + ): + evalue, evector = linalg.largest_eigenvector(M, key, num_iters) + v = jnp.sqrt(evalue) * evector + low_rank = v[:, jnp.newaxis] + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class PartialCholeskyPreconditioner(LowRankPreconditioner): + """https://en.wikipedia.org/wiki/Incomplete_Cholesky_factorization .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + rank: int = 20, + **unused_kwargs, + ): + n = M.shape[-1] + rank = min(n, rank) + low_rank, _, residual_diag = tfp.math.low_rank_cholesky(M, rank) + super().__init__(M, low_rank, residual_diag) + + +@jax.tree_util.register_pytree_node_class +class PartialLanczosPreconditioner(LowRankPreconditioner): + """https://www.sciencedirect.com/science/article/pii/S0307904X13002382 .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + rank: int = 20, + **unused_kwargs, + ): + low_rank = linalg.make_partial_lanczos(key, M, rank) + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class TruncatedSvdPreconditioner(LowRankPreconditioner): + """https://www.math.kent.edu/~reichel/publications/tsvd.pdf . + + Note that 5 * num_iters must be less than n. + """ + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + rank: int = 20, + num_iters: int = 10, + **unused_kwargs, + ): + low_rank = linalg.make_truncated_svd(key, M, rank, num_iters) + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class TruncatedRandomizedSvdPreconditioner(LowRankPreconditioner): + """https://www.math.kent.edu/~reichel/publications/tsvd.pdf . + + Note that 5 * num_iters must be less than n. 
+ """ + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + rank: int = 20, + num_iters: int = 10, + **unused_kwargs, + ): + low_rank = linalg.make_randomized_truncated_svd(key, M, rank, num_iters) + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class LowRankPlusScalingPreconditioner(Preconditioner): + """Turns M ~ a * I + A A^t for low rank A into a preconditioner.""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + low_rank: jax.Array, + scaling: jax.Array, + ): + self.low_rank = low_rank + n, self.r = self.low_rank.shape + assert n == M.shape[-1], ( + f'Low Rank has shape {self.low_rank.shape}; should have shape' + f' ({M.shape[-1]}, r)' + ) + self.scaling = scaling + identity_op = jtf.linalg.LinearOperatorScaledIdentity( + num_rows=M.shape[-1], + multiplier=self.scaling, + is_non_singular=True, + is_positive_definite=True) + self.pre = jtf.linalg.LinearOperatorLowRankUpdate( + identity_op, + self.low_rank, + is_positive_definite=True, + is_self_adjoint=True, + is_non_singular=True) + super().__init__(M) + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + return self.pre + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + linop = promote_to_operator(self.M) + operator = linear_operator_sum.LinearOperatorSum( + [linop, + jtf.linalg.LinearOperatorScaledIdentity( + num_rows=self.M.shape[-1], + multiplier=self.scaling)]) + return jtf.linalg.LinearOperatorComposition([self.pre.inverse(), operator]) + + @classmethod + def from_lowrank(cls, M, low_rank, scaling): + """Alternate constructor when low_rank is already made.""" + x = LowRankPlusScalingPreconditioner(M, low_rank, scaling) + x.__class__ = cls + return x + + def tree_flatten(self): + return ((self.M, self.low_rank, self.scaling), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls.from_lowrank(*children) + + +@jax.tree_util.register_pytree_node_class +class PartialCholeskyPlusScalingPreconditioner( + LowRankPlusScalingPreconditioner): + """https://en.wikipedia.org/wiki/Incomplete_Cholesky_factorization .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + scaling: jax.Array, + rank: int = 20, + **unused_kwargs, + ): + n = M.shape[-1] + rank = min(n, rank) + low_rank, _, _ = tfp.math.low_rank_cholesky(M, rank) + super().__init__(M, low_rank, scaling) + + +@jax.tree_util.register_pytree_node_class +class PartialPivotedCholeskyPlusScalingPreconditioner( + LowRankPlusScalingPreconditioner): + """https://en.wikipedia.org/wiki/Incomplete_Cholesky_factorization .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + scaling: jax.Array, + rank: int = 20, + **unused_kwargs, + ): + n = M.shape[-1] + rank = min(n, rank) + low_rank = linalg.make_partial_pivoted_cholesky(M, rank) + super().__init__(M, low_rank, scaling) + + +@jax.tree_util.register_pytree_node_class +class PartialLanczosPlusScalingPreconditioner( + LowRankPlusScalingPreconditioner): + """https://www.sciencedirect.com/science/article/pii/S0307904X13002382 .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + scaling: jax.Array, + key: jax.Array, + rank: int = 20, + **unused_kwargs, + ): + low_rank = linalg.make_partial_lanczos(key, M, rank) + super().__init__(M, low_rank, scaling) + + +@jax.tree_util.register_pytree_node_class +class TruncatedSvdPlusScalingPreconditioner(LowRankPlusScalingPreconditioner): + """https://www.math.kent.edu/~reichel/publications/tsvd.pdf . + + Note that 5 * num_iters must be less than n. 
+ """ + + def __init__( + self, + M: jtf.linalg.LinearOperator, + scaling: jax.Array, + key: jax.Array, + rank: int = 20, + num_iters: int = 10, + **unused_kwargs, + ): + low_rank = linalg.make_truncated_svd(key, M, rank, num_iters) + super().__init__(M, low_rank, scaling) + + +@jax.tree_util.register_pytree_node_class +class TruncatedRandomizedSvdPlusScalingPreconditioner( + LowRankPlusScalingPreconditioner): + """https://www.math.kent.edu/~reichel/publications/tsvd.pdf . + + Note that 5 * num_iters must be less than n. + """ + + def __init__( + self, + M: jtf.linalg.LinearOperator, + scaling: jax.Array, + key: jax.Array, + rank: int = 20, + num_iters: int = 10, + **unused_kwargs, + ): + low_rank = linalg.make_randomized_truncated_svd(key, M, rank, num_iters) + super().__init__(M, low_rank, scaling) + + +class SplitPreconditioner(Preconditioner): + """Base class for symmetric split preconditioners.""" + + # pylint: disable-next=useless-parent-delegation + def __init__(self, M: jtf.linalg.LinearOperator): + super().__init__(M) + + def right_half(self) -> jtf.linalg.LinearOperator: + """Returns R, where the preconditioner is P = R^T R.""" + raise NotImplementedError('Base classes must override right_half method.') + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + """Returns P = R^T R, the preconditioner's approximation to M.""" + rh = self.right_half() + lh = rh.adjoint() + return jtf.linalg.LinearOperatorComposition( + [lh, rh], + is_self_adjoint=True, + is_positive_definite=True, + ) + + def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + """Returns R^(-T) M R^(-1).""" + rhi = self.right_half().inverse() + lhi = rhi.adjoint() + return jtf.linalg.LinearOperatorComposition( + [lhi, promote_to_operator(self.M), rhi], + is_self_adjoint=True, + is_positive_definite=True, + ) + + def log_det(self) -> Float: + """Returns log det(R^T R) = 2 log det R.""" + return 2 * self.right_half().log_abs_determinant() + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + """Returns tr( (R^T R)^(-1) A ) for a n x n, non-batched A.""" + raise NotImplementedError( + 'Base classes must override trace_of_inverse_product.') + + +@jax.tree_util.register_pytree_node_class +class DiagonalSplitPreconditioner(SplitPreconditioner): + """The split conditioner which pre and post multiplies by a diagonal.""" + + def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs): + self.d = jnp.maximum(_diag_part(M), 1e-6) + self.sqrt_d = jnp.sqrt(self.d) + super().__init__(M) + + def right_half(self) -> jtf.linalg.LinearOperator: + return jtf.linalg.LinearOperatorDiag( + self.sqrt_d, is_non_singular=True, is_positive_definite=True + ) + + def full_preconditioner(self) -> jtf.linalg.LinearOperator: + return jtf.linalg.LinearOperatorDiag( + self.d, is_non_singular=True, is_positive_definite=True + ) + + def log_det(self) -> Float: + return jnp.sum(jnp.log(self.d)) + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + return jnp.sum(jnp.diag(A) / self.d) + + def tree_flatten(self): + return ((self.M,), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls(*children) + + +@jax.tree_util.register_pytree_node_class +class LowRankSplitPreconditioner(SplitPreconditioner): + """Turns M ~ A A^t for low rank A into a split preconditioner.""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + low_rank: jax.Array, + residual_diag: jax.Array = None, + ): + self.low_rank = low_rank + n, self.r = self.low_rank.shape + assert n == M.shape[-1], ( 
+ f'Low Rank has shape {self.low_rank.shape}; should have shape' + f' ({M.shape[-1]}, r)' + ) + + if residual_diag is None: + self.residual_diag = _diag_part(M) - jnp.einsum( + 'ij,ij->i', self.low_rank, self.low_rank + ) + else: + self.residual_diag = residual_diag + + self.residual_diag = jnp.maximum(1e-6, self.residual_diag) + + # self.low_rank isn't invertible, so we need to make it part of a block + # matrix that is. self.low_rank will be the first r columns of that + # matrix; the other n-r columns will be zero for the first r rows and + # a diagonal matrix in the remaining (n-r, n-r) bottom right block. + # TODO(thomaswc): Iteratively add a small constant to the diagonal of B if + # it is singular. + # TODO(thomaswc): Permute the indices so that the lowest r values of + # self.residual_diag are in the first r entries (so that the diagonal + # matrix in the bottom right block can kill the largest n-r residual_diag + # values). + # Turn off Pyformat because it puts spaces between a slice bound and its + # colon. + # fmt: off + self.B = jtf.linalg.LinearOperatorFullMatrix( + self.low_rank[:self.r, :], is_non_singular=True + ) + self.C = jtf.linalg.LinearOperatorFullMatrix(self.low_rank[self.r:, :]) + sqrt_d = jnp.sqrt(self.residual_diag[self.r:]) + # fmt: on + self.D = jtf.linalg.LinearOperatorDiag(sqrt_d, is_non_singular=True) + P = tfp.tf2jax.linalg.LinearOperatorBlockLowerTriangular( + [[self.B], [self.C, self.D]], is_non_singular=True + ) + # We started from M ~ low_rank low_rank^t (because low_rank is n by r), + # & we need the preconditioner P to satisfy M ~ P^t P, so P = low_rank^t. + self.P = P.adjoint() + + super().__init__(M) + + def right_half(self) -> jtf.linalg.LinearOperator: + return self.P + + def trace_of_inverse_product(self, A: jax.Array) -> Float: + # We want the trace of (P^T P)^(-1) A + # = P^(-1) P^(-t) A + # = [[ B^(-1), - B^(-1) C D^(-1)], [0, D^(-1)]] + # @ [[ B^(-t), 0], [- D^(-1) C^t B^(-t), D^(-1)]] + # @ [[ A11, A12 ], [ A21, A22 ]] + # = [[ B^(-1) B^(-t) A11 - B^(-1) C D^(-2) (A21 - C^t B^(-t) A11), *], + # [*, D^(-2) ( A22 - C^t B^(-t) A12 ) ]] + # (But actually, because of the self.P = P.adjoint at the end of + # __init__, B is actually always B^t in the above and C is always + # C^t. + n = A.shape[-1] + if self.r == n: + return jnp.trace(self.full_preconditioner().solvevec(A)) + A11 = A[:self.r, :self.r] + A12 = A[:self.r, self.r:] + A21 = A[self.r:, :self.r] + A22 = A[self.r:, self.r:] + D2 = self.residual_diag[self.r:] + # TODO(thomaswc): Compute the LU decomposition of B, and use that in + # place of all of the self.B.solvevec's. 
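+    # The code below accumulates the traces of the two diagonal blocks in the
+    # expression above: first_term and second_term come from the top left
+    # block, while A22_term and A12_term come from the bottom right block.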
+ Binvt_A11 = self.B.solvevec(A11) + first_term = jnp.trace(self.B.H.solvevec(Binvt_A11)) + inner_factor = (A21 - self.C @ Binvt_A11) / D2[:, jnp.newaxis] + second_term = jnp.trace(self.B.H.solvevec(self.C.H @ inner_factor)) + A22_term = jnp.sum(jnp.diag(A22) / D2) + Binvt_A12 = self.B.solve(A12) + diag_Ct_Binvt_A12 = jnp.einsum('ij,ji->i', self.C.to_dense(), Binvt_A12) + A12_term = jnp.sum(diag_Ct_Binvt_A12 / D2) + return first_term - second_term + A22_term - A12_term + + @classmethod + def from_lowrank(cls, M, low_rank): + """Alternate constructor when low_rank is already made.""" + x = LowRankSplitPreconditioner(M, low_rank) + x.__class__ = cls + return x + + def tree_flatten(self): + return ((self.M, self.low_rank), None) + + @classmethod + def tree_unflatten(cls, unused_aux_data, children): + return cls.from_lowrank(*children) + + +@jax.tree_util.register_pytree_node_class +class RankOneSplitPreconditioner(LowRankSplitPreconditioner): + """Split preconditioner based on M ~ v v^t using M's largest eigenvector v.""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + num_iters: int = 10, + **unused_kwargs, + ): + evalue, evector = linalg.largest_eigenvector(M, key, num_iters) + v = jnp.sqrt(evalue) * evector + low_rank = v[:, jnp.newaxis] + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class PartialCholeskySplitPreconditioner(LowRankSplitPreconditioner): + """https://en.wikipedia.org/wiki/Incomplete_Cholesky_factorization .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + rank: int = 20, + **unused_kwargs, + ): + n = M.shape[-1] + rank = min(n, rank) + low_rank, _, residual_diag = tfp.math.low_rank_cholesky(M, rank) + super().__init__(M, low_rank, residual_diag) + + +@jax.tree_util.register_pytree_node_class +class PartialLanczosSplitPreconditioner(LowRankSplitPreconditioner): + """https://www.sciencedirect.com/science/article/pii/S0307904X13002382 .""" + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + rank: int = 20, + **unused_kwargs, + ): + low_rank = linalg.make_partial_lanczos(key, M, rank) + super().__init__(M, low_rank) + + +@jax.tree_util.register_pytree_node_class +class TruncatedSvdSplitPreconditioner(LowRankSplitPreconditioner): + """https://www.math.kent.edu/~reichel/publications/tsvd.pdf . + + Note that 5 * num_iters must be less than n. + """ + + def __init__( + self, + M: jtf.linalg.LinearOperator, + key: jax.Array, + rank: int = 20, + num_iters: int = 10, + **unused_kwargs, + ): + low_rank = linalg.make_truncated_svd(key, M, rank, num_iters) + super().__init__(M, low_rank) + + +# TODO(thomaswc): RFF and QFF preconditioners. 
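+
+# For any SplitPreconditioner with R = right_half(), the definitions above
+# give full_preconditioner() == R^T R and
+# preconditioned_operator() == R^(-T) M R^(-1), hence
+#   log det M == log_det() + log det(preconditioned_operator()),
+# and the preconditioned operator is typically far better conditioned than M
+# (see test_post_conditioned in preconditioners_test.py).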
+ + +PRECONDITIONER_REGISTRY = { + 'identity': IdentityPreconditioner, + 'diagonal': DiagonalPreconditioner, + 'rank_one': RankOnePreconditioner, + 'partial_cholesky': PartialCholeskyPreconditioner, + 'partial_lanczos': PartialLanczosPreconditioner, + 'truncated_svd': TruncatedSvdPreconditioner, + 'truncated_randomized_svd': TruncatedRandomizedSvdPreconditioner, + 'diagonal_split': DiagonalSplitPreconditioner, + 'rank_one_split': RankOneSplitPreconditioner, + 'partial_cholesky_split': PartialCholeskySplitPreconditioner, + 'partial_lanczos_split': PartialLanczosSplitPreconditioner, + 'truncated_svd_split': TruncatedSvdSplitPreconditioner, + 'partial_pivoted_cholesky_plus_scaling': ( + PartialPivotedCholeskyPlusScalingPreconditioner), + 'partial_cholesky_plus_scaling': PartialCholeskyPlusScalingPreconditioner, + 'partial_lanczos_plus_scaling': PartialLanczosPlusScalingPreconditioner, + 'truncated_svd_plus_scaling': TruncatedSvdPlusScalingPreconditioner, + 'truncated_randomized_svd_plus_scaling': ( + TruncatedRandomizedSvdPlusScalingPreconditioner), +} + + +@jax.named_call +def get_preconditioner( + preconditioner_name: str, M: jtf.linalg.LinearOperator, **kwargs +) -> SplitPreconditioner: + """Return the preconditioner of the given type for the given matrix.""" + if preconditioner_name == 'auto': + n = M.shape[-1] + rank = kwargs.get('rank', 20) + if 5 * rank >= n: + preconditioner_name = 'partial_cholesky_split' + else: + preconditioner_name = 'truncated_svd' + try: + return PRECONDITIONER_REGISTRY[preconditioner_name](M, **kwargs) + except KeyError as key_error: + raise ValueError( + 'Unknown preconditioner name {}, known preconditioners are {}'.format( + preconditioner_name, PRECONDITIONER_REGISTRY.keys())) from key_error diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py b/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py new file mode 100644 index 0000000000..a0e0c54bd4 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py @@ -0,0 +1,701 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test for preconditioners.py.""" + +from absl.testing import parameterized +import jax +from jax import config +import jax.numpy as jnp +import numpy as np +from tensorflow_probability.python.experimental.fastgp import preconditioners +import tensorflow_probability.substrates.jax as tfp +from absl.testing import absltest + +jtf = tfp.tf2jax + + +# pylint: disable=invalid-name + + +class _PreconditionersTest(parameterized.TestCase): + + def test_identity_preconditioner(self): + idp = preconditioners.IdentityPreconditioner( + jnp.identity(3).astype(self.dtype)) + x = jnp.array([1.0, 2.0, 3.0], dtype=self.dtype) + np.testing.assert_allclose(x, idp.full_preconditioner().solvevec(x)) + np.testing.assert_allclose(x, idp.preconditioned_operator().matvec(x)) + self.assertAlmostEqual(0.0, idp.log_det()) + + def test_diagonal_preconditioner(self): + dp = preconditioners.DiagonalPreconditioner( + jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + ) + self.assertAlmostEqual(jnp.log(4.0), dp.log_det(), places=5) + np.testing.assert_allclose( + jnp.array([1.0, 0.25]), + dp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-6, + ) + np.testing.assert_allclose( + jnp.array([1.0, 4.0]), + dp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-6, + ) + + def test_diagonal_split_preconditioner(self): + dp = preconditioners.DiagonalSplitPreconditioner( + jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype) + ) + self.assertAlmostEqual(jnp.log(4.0), dp.log_det(), places=5) + np.testing.assert_allclose( + jnp.array([1.0, 0.25]), + dp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-6, + ) + np.testing.assert_allclose( + jnp.array([1.0, 4.0]), + dp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-6, + ) + + def test_rank_one_preconditioner(self): + r1p = preconditioners.RankOnePreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(0)) + # The true rank one approximation to M here has v ~ [1.30, 2.02], + # which leads to the full matrix ~ [[1.3, 0.0], [2.02, 0.28]] + # which has det ~ 0.37 and log det ~ -1. Which then implies that + # the full preconditioner has log det ~ -2. 
+ self.assertAlmostEqual(1.7164037, r1p.log_det(), delta=4) + np.testing.assert_allclose( + jnp.array([0.4, 0.1]), + r1p.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1, + ) + np.testing.assert_allclose( + jnp.array([3.5, 5.6]), + r1p.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + + def test_partial_cholesky_preconditioner(self): + pcp = preconditioners.PartialCholeskyPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + ) + self.assertAlmostEqual(jnp.log(7.0), pcp.log_det(), places=5) + atol = 1e-6 + if self.dtype == np.float32: + atol = 2e-2 + + np.testing.assert_allclose( + jnp.array([3/7., 1/7.]), + pcp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=atol, + ) + np.testing.assert_allclose( + jnp.array([3.0, 5.0]), + pcp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=atol, + ) + + def test_partial_lanczos_preconditioner(self): + plp = preconditioners.PartialLanczosPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(5) + ) + self.assertAlmostEqual(jnp.log(7.0), plp.log_det(), places=5) + atol = 1e-6 + if self.dtype == np.float32: + atol = 3e-1 + np.testing.assert_allclose( + jnp.array([3/7., 1/7.]), + plp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=atol, + ) + np.testing.assert_allclose( + jnp.array([3.0, 5.0]), + plp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=atol, + ) + + def test_truncated_svd_preconditioner(self): + tsvd = preconditioners.TruncatedSvdPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(1) + ) + self.assertAlmostEqual(jnp.log(7.0), tsvd.log_det(), delta=0.2) + np.testing.assert_allclose( + jnp.array([0.5, 0.25]), + tsvd.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + np.testing.assert_allclose( + jnp.array([2.0, 4.0]), + tsvd.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + + def test_rank_one_split_preconditioner(self): + r1p = preconditioners.RankOneSplitPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(0) + ) + # The true rank one approximation to M here has v ~ [1.30, 2.02], + # which leads to the full matrix ~ [[1.3, 0.0], [2.02, 0.28]] + # which has det ~ 0.37 and log det ~ -1. Which then implies that + # the full preconditioner has log det ~ -2. 
+ self.assertAlmostEqual(1.7164037, r1p.log_det(), delta=4) + np.testing.assert_allclose( + jnp.array([16, -6]), + r1p.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1, + ) + np.testing.assert_allclose( + jnp.array([2.2, 5.6]), + r1p.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + + def test_partial_cholesky_split_preconditioner(self): + pcp = preconditioners.PartialCholeskySplitPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + ) + self.assertAlmostEqual(jnp.log(7.0), pcp.log_det(), places=5) + np.testing.assert_allclose( + jnp.array([3/7., 1/7.]), + pcp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + np.testing.assert_allclose( + jnp.array([3.0, 5.0]), + pcp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + + def test_partial_lanczos_split_preconditioner(self): + plp = preconditioners.PartialLanczosSplitPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(1) + ) + self.assertAlmostEqual(jnp.log(7.0), plp.log_det(), places=5) + np.testing.assert_allclose( + jnp.array([3/7., 1/7.]), + plp.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=3e-6, + ) + np.testing.assert_allclose( + jnp.array([3.0, 5.0]), + plp.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=3e-6, + ) + + def test_truncated_svd_split_preconditioner(self): + tsvd = preconditioners.TruncatedSvdSplitPreconditioner( + jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype), + jax.random.PRNGKey(1) + ) + self.assertAlmostEqual(jnp.log(7.0), tsvd.log_det(), delta=0.2) + np.testing.assert_allclose( + jnp.array([0.5, 0.25]), + tsvd.full_preconditioner().solvevec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + np.testing.assert_allclose( + jnp.array([2.0, 4.0]), + tsvd.full_preconditioner().matvec( + jnp.array([1.0, 1.0], dtype=self.dtype)), + atol=1e-1, + ) + + @parameterized.parameters( + ("auto", 1.2e-7, 1e-6), + ("identity", 2.0, 1.0), + ("diagonal", 0.2, 0.2), + ("rank_one", 0.3, 0.1), + ("partial_cholesky", 1e-6, 0.02), + ("partial_lanczos", 1e-6, 0.1), + ("truncated_svd", 0.2, 0.2), + ("truncated_randomized_svd", 0.2, 0.2), + ("diagonal_split", 0.2, 0.2), + ("rank_one_split", 4.0, 20.0), + ("partial_cholesky_split", 1.2e-7, 1e-6), + ("partial_lanczos_split", 4e-7, 1e-6), + ("truncated_svd_split", 0.2, 0.2), + ) + def test_get_preconditioner(self, preconditioner, log_det_delta, solve_atol): + m = jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + v = jnp.array([1.0, 1.0], dtype=self.dtype) + + p = preconditioners.get_preconditioner( + preconditioner, m, key=jax.random.PRNGKey(3)) + self.assertAlmostEqual(jnp.log(7.0), p.log_det(), delta=log_det_delta) + np.testing.assert_allclose( + jnp.array([0.42857143, 0.14285714]), + p.full_preconditioner().solvevec(v), + atol=solve_atol + ) + + @parameterized.parameters( + ("partial_cholesky_plus_scaling", 4e-7, 1e-7), + ("partial_lanczos_plus_scaling", 3e-7, 2e-7), + ("partial_pivoted_cholesky_plus_scaling", 2e-7, 2e-7), + # TODO(srvasude): Test this on larger matrices, since + # the low rank decomposition for small matrices is just a zero matrix. 
+ # ("truncated_svd_plus_scaling", 3e-7, 1e-7), + ) + def test_get_preconditioner_with_identity( + self, preconditioner, log_det_delta, solve_atol): + m = jnp.array([[1.0, 1.0], [1.0, 3.0]], dtype=self.dtype) + v = jnp.array([1.0, 1.0], dtype=self.dtype) + + p = preconditioners.get_preconditioner( + preconditioner, + m, + key=jax.random.PRNGKey(3), + scaling=self.dtype(1.)) + self.assertAlmostEqual(jnp.log(7.0), p.log_det(), delta=log_det_delta) + np.testing.assert_allclose( + jnp.array([0.42857143, 0.14285714]), + p.full_preconditioner().solvevec(v), + atol=solve_atol + ) + + @parameterized.parameters( + ("partial_cholesky_plus_scaling", 0.5, 0.3), + ("partial_lanczos_plus_scaling", 1.3, 0.9), + ("partial_pivoted_cholesky_plus_scaling", 0.5, 0.3), + ("truncated_randomized_svd_plus_scaling", 0.5, 0.3), + # TODO(srvasude): Test this on larger matrices, since + # the low rank decomposition for small matrices is just a zero matrix. + # ("truncated_svd_plus_scaling", 1.2e-7, 1e-7), + ) + def test_get_preconditioner_with_scaling_kwargs( + self, preconditioner, log_det_delta, solve_atol): + m = jnp.array([[1.0, 1.0], [1.0, 3.0]], dtype=self.dtype) + v = jnp.array([1.0, 1.0], dtype=self.dtype) + + p = preconditioners.get_preconditioner( + preconditioner, + m, + rank=1, + num_iters=5, + scaling=self.dtype(1.), + key=jax.random.PRNGKey(4) + ) + self.assertAlmostEqual(jnp.log(7.0), p.log_det(), delta=log_det_delta) + np.testing.assert_allclose( + jnp.array([0.42857143, 0.14285714]), + p.full_preconditioner().solvevec(v), + atol=solve_atol + ) + + @parameterized.parameters( + ("identity", 2.0, 1.0), + ("diagonal", 0.2, 0.2), + ("rank_one", 0.3, 0.1), + ("partial_cholesky", 1e-6, 0.2), + ("partial_lanczos", 0.6, 0.3), + ("truncated_svd", 0.2, 0.2), + ("truncated_randomized_svd", 0.3, 0.3), + ("diagonal_split", 0.2, 0.2), + ("rank_one_split", 4.0, 16.0), + # partial_cholesky_split is basically broken until index permutations + # are added to LowRankSplitPreconditioner. + # ("partial_cholesky_split", 20.0, 0), + ("partial_lanczos_split", 0.9, 1.3), + ("truncated_svd_split", 0.2, 0.2), + ) + def test_get_preconditioner_with_kwargs( + self, preconditioner, log_det_delta, solve_atol): + m = jnp.array([[2.0, 1.0], [1.0, 4.0]], dtype=self.dtype) + v = jnp.array([1.0, 1.0], dtype=self.dtype) + + p = preconditioners.get_preconditioner( + preconditioner, m, rank=1, num_iters=5, key=jax.random.PRNGKey(4) + ) + self.assertAlmostEqual(jnp.log(7.0), p.log_det(), delta=log_det_delta) + np.testing.assert_allclose( + jnp.array([0.42857143, 0.14285714]), + p.full_preconditioner().solvevec(v), + atol=solve_atol + ) + + @parameterized.parameters( + ("identity", 9000), + ("diagonal", 9000), + ("rank_one", 3100), + ("partial_cholesky", 5.0), + ("partial_lanczos", 9000.), + ("truncated_svd", 3100.0), + ("truncated_randomized_svd", 3100.0), + ("diagonal_split", 9000.), + ("rank_one_split", 3100), + # partial_cholesky_split is basically broken until index permutations + # are added to LowRankSplitPreconditioner. 
+ # ("partial_cholesky_split", 2.0), + ("partial_lanczos_split", 2e9), + ("truncated_svd_split", 3100.0), + ) + def test_post_conditioned(self, preconditioner, condition_number_bound): + M = jnp.array([ + [ + 1.001, + 0.88311934, + 0.9894911, + 0.9695768, + 0.9987461, + 0.98577714, + 0.97863793, + 0.9880289, + 0.7110599, + 0.7718459, + ], + [ + 0.88311934, + 1.001, + 0.9395206, + 0.7564426, + 0.86025584, + 0.94721663, + 0.7791884, + 0.8075757, + 0.9478641, + 0.9758552, + ], + [ + 0.9894911, + 0.9395206, + 1.001, + 0.92534095, + 0.98108065, + 0.9997143, + 0.93953925, + 0.95583755, + 0.79332554, + 0.84795874, + ], + [ + 0.9695768, + 0.7564426, + 0.92534095, + 1.001, + 0.98049456, + 0.91640615, + 0.9991695, + 0.99564964, + 0.5614807, + 0.6257758, + ], + [ + 0.9987461, + 0.86025584, + 0.98108065, + 0.98049456, + 1.001, + 0.97622854, + 0.98763895, + 0.99449164, + 0.6813891, + 0.74358207, + ], + [ + 0.98577714, + 0.94721663, + 0.9997143, + 0.91640615, + 0.97622854, + 1.001, + 0.9313745, + 0.9487237, + 0.80610526, + 0.859435, + ], + [ + 0.97863793, + 0.7791884, + 0.93953925, + 0.9991695, + 0.98763895, + 0.9313745, + 1.001, + 0.99861676, + 0.5861309, + 0.65042824, + ], + [ + 0.9880289, + 0.8075757, + 0.95583755, + 0.99564964, + 0.99449164, + 0.9487237, + 0.99861676, + 1.001, + 0.61803514, + 0.68201244, + ], + [ + 0.7110599, + 0.9478641, + 0.79332554, + 0.5614807, + 0.6813891, + 0.80610526, + 0.5861309, + 0.61803514, + 1.001, + 0.9943819, + ], + [ + 0.7718459, + 0.9758552, + 0.84795874, + 0.6257758, + 0.74358207, + 0.859435, + 0.65042824, + 0.68201244, + 0.9943819, + 1.001, + ], + ], dtype=self.dtype) + + pc = preconditioners.get_preconditioner( + preconditioner, M, rank=5, key=jax.random.PRNGKey(5) + ) + post_conditioned = pc.preconditioned_operator().to_dense() + + # For split operators, post conditioned should be symmetric. + if "_split" in preconditioner: + np.testing.assert_allclose( + post_conditioned, post_conditioned.T, rtol=1e-2, atol=1e-4) + + # log det post_conditioned = log det M - pc.log_det + _, post_log_det = jnp.linalg.slogdet(post_conditioned) + _, M_log_det = jnp.linalg.slogdet(M) + # TODO(thomaswc): Figure out why the precoditioner log det calculations + # are so wrong for partial_cholesky and partial_lanczos. + if preconditioner not in [ + "partial_cholesky", "partial_lanczos", "truncated_randomized_svd"]: + self.assertAlmostEqual(post_log_det, M_log_det - pc.log_det(), delta=1e-3) + + evalues = jnp.linalg.eigvalsh(post_conditioned) + post_cond_number = evalues[-1] / jnp.abs(evalues[0]) + + self.assertLess(post_cond_number, condition_number_bound) + + # For split operators, check that the post conditioned matrix is still + # positive definite. 
+ lower_bound = 0.0 + if "_split" in preconditioner: + self.assertGreater(evalues[0], lower_bound) + + def test_are_pytrees(self): + M = jnp.array([[5.0, 1.0], [1.0, 9.0]], dtype=self.dtype) + for pc, expected_leaves in { + "auto": 2, + "identity": 1, + "diagonal": 1, + "rank_one": 2, + "partial_cholesky": 2, + "partial_lanczos": 2, + "truncated_svd": 2, + "diagonal_split": 1, + "rank_one_split": 2, + "partial_cholesky_split": 2, + "partial_lanczos_split": 2, + "truncated_svd_split": 2, + }.items(): + p = preconditioners.get_preconditioner( + pc, M, rank=5, key=jax.random.PRNGKey(6) + ) + self.assertLen(jax.tree_util.tree_leaves(p), expected_leaves, + f"Expected {expected_leaves} leaves for {pc}") + + def test_flatten_unflatten(self): + M = jnp.array([[5.0, 1.0], [1.0, 9.0]], dtype=self.dtype) + ip = preconditioners.IdentityPreconditioner(M) + leaves, treedef = jax.tree_util.tree_flatten(ip) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance(rep, preconditioners.IdentityPreconditioner) + + dp = preconditioners.DiagonalSplitPreconditioner(M) + leaves, treedef = jax.tree_util.tree_flatten(dp) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance(rep, preconditioners.DiagonalSplitPreconditioner) + np.testing.assert_allclose( + dp.right_half().to_dense(), rep.right_half().to_dense() + ) + + r1p = preconditioners.RankOneSplitPreconditioner( + M, key=jax.random.PRNGKey(7) + ) + leaves, treedef = jax.tree_util.tree_flatten(r1p) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance(rep, preconditioners.RankOneSplitPreconditioner) + np.testing.assert_allclose( + r1p.right_half().to_dense(), rep.right_half().to_dense() + ) + + pcp = preconditioners.PartialCholeskySplitPreconditioner( + M, key=jax.random.PRNGKey(7) + ) + leaves, treedef = jax.tree_util.tree_flatten(pcp) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance( + rep, preconditioners.PartialCholeskySplitPreconditioner + ) + np.testing.assert_allclose( + pcp.right_half().to_dense(), rep.right_half().to_dense() + ) + + plp = preconditioners.PartialLanczosSplitPreconditioner( + M, key=jax.random.PRNGKey(7) + ) + leaves, treedef = jax.tree_util.tree_flatten(plp) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance( + rep, preconditioners.PartialLanczosSplitPreconditioner + ) + np.testing.assert_allclose( + plp.right_half().to_dense(), rep.right_half().to_dense() + ) + + tsvd = preconditioners.TruncatedSvdSplitPreconditioner( + M, key=jax.random.PRNGKey(7) + ) + leaves, treedef = jax.tree_util.tree_flatten(tsvd) + rep = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertIsInstance( + rep, preconditioners.TruncatedSvdSplitPreconditioner + ) + np.testing.assert_allclose( + tsvd.right_half().to_dense(), rep.right_half().to_dense() + ) + + @parameterized.parameters( + ("auto", 1), + ("auto", 2), + ("auto", 10), + ("identity", 1), + ("identity", 10), + ("diagonal", 1), + ("diagonal", 10), + ("rank_one", 1), + ("rank_one", 2), + ("rank_one", 10), + ("partial_cholesky", 1), + ("partial_cholesky", 10), + ("partial_lanczos", 1), + ("partial_lanczos", 2), + ("partial_lanczos", 10), + ("truncated_svd", 1), + ("truncated_svd", 2), + ("truncated_svd", 10), + ("truncated_randomized_svd", 1), + ("truncated_randomized_svd", 2), + ("truncated_randomized_svd", 10), + ("diagonal_split", 1), + ("diagonal_split", 10), + ("rank_one_split", 1), + ("rank_one_split", 2), + ("rank_one_split", 10), + ("partial_cholesky_split", 1), + 
("partial_cholesky_split", 10), + ("partial_lanczos_split", 1), + ("partial_lanczos_split", 2), + # TODO(srvasude): Re-enable this when partial-lanczos has + # reorthogonalization added to it. + # ("partial_lanczos_split", 10), + ("truncated_svd_split", 1), + ("truncated_svd_split", 2), + ("truncated_svd_split", 10), + ) + def test_trace_of_inverse_product(self, preconditioner, n): + # Make a random symmetric positive definite matrix M. + A = jax.random.uniform(jax.random.PRNGKey(8), shape=(n, n), + minval=-1.0, maxval=1.0).astype(self.dtype) + M = A.T @ A + 0.6 * jnp.eye(n).astype(self.dtype) + # Make a random, not necessarily symmetric or positive definite matrix B. + B = jax.random.uniform(jax.random.PRNGKey(10), shape=(n, n), + minval=-1.0, maxval=1.0).astype(self.dtype) + p = preconditioners.get_preconditioner( + preconditioner, M, key=jax.random.PRNGKey(9), rank=5) + true_trace = jnp.trace(p.full_preconditioner().solve(B)) + error = abs(true_trace - p.trace_of_inverse_product(B)) + relative_error = error / true_trace + self.assertLess(relative_error, 0.001) + # AlmostEqual(true_trace, p.trace_of_inverse_product(B), places=2) + + @parameterized.parameters( + "auto", + "diagonal", + "identity", + "partial_cholesky", + "partial_lanczos", + "rank_one", + "truncated_svd", + "truncated_randomized_svd", + "diagonal_split", + "partial_cholesky_split", + "partial_lanczos_split", + "rank_one_split", + "truncated_svd_split" + ) + def test_preconditioner_with_linop(self, preconditioner): + # Make a random symmetric positive definite matrix M. + A = jax.random.uniform(jax.random.PRNGKey(8), shape=(2, 2), + minval=-1.0, maxval=1.0).astype(self.dtype) + M = A.T @ A + 0.6 * jnp.eye(2).astype(self.dtype) + M = jtf.linalg.LinearOperatorFullMatrix(M) + # There are no errors. + _ = preconditioners.get_preconditioner( + preconditioner, M, key=jax.random.PRNGKey(9), rank=5) + + +class PreconditionersTestFloat32(_PreconditionersTest): + dtype = np.float32 + + +class PreconditionersTestFloat64(_PreconditionersTest): + dtype = np.float64 + + +del _PreconditionersTest + + +if __name__ == "__main__": + config.update("jax_enable_x64", True) + absltest.main() diff --git a/tensorflow_probability/python/experimental/fastgp/schur_complement.py b/tensorflow_probability/python/experimental/fastgp/schur_complement.py new file mode 100644 index 0000000000..1e7e3990fa --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/schur_complement.py @@ -0,0 +1,232 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""The SchurComplement kernel.""" + +import jax +import jax.numpy as jnp +from tensorflow_probability.python.experimental.fastgp import mbcg +from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.substrates.jax.math.psd_kernels.internal import util + +jtf = tfp.tf2jax +parameter_properties = tfp.internal.parameter_properties + + +__all__ = [ + 'SchurComplement', +] + + +def _add_diagonal_shift(matrix, shift): + return matrix + shift[..., jnp.newaxis] * jnp.eye( + matrix.shape[-1], dtype=matrix.dtype) + + +def _compute_divisor_matrix( + base_kernel, diag_shift, fixed_inputs): + """Compute the modified kernel with respect to the fixed inputs.""" + divisor_matrix = base_kernel.matrix(fixed_inputs, fixed_inputs) + if diag_shift is not None: + broadcast_shape = tfp.internal.distribution_util.get_broadcast_shape( + divisor_matrix, diag_shift[..., jnp.newaxis, jnp.newaxis]) + divisor_matrix = jnp.broadcast_to(divisor_matrix, broadcast_shape) + divisor_matrix = _add_diagonal_shift( + divisor_matrix, diag_shift[..., jnp.newaxis]) + return divisor_matrix + + +class SchurComplement(tfp.math.psd_kernels.AutoCompositeTensorPsdKernel): + """The fast SchurComplement kernel. + + See tfp.math.psd_kernels.SchurComplement for more details. + + """ + + def __init__(self, + base_kernel, + fixed_inputs, + preconditioner_fn, + diag_shift=None): + """Construct a SchurComplement kernel instance. + + Args: + base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to + build the block matrices of which this kernel computes the Schur + complement. + fixed_inputs: A (nested) Tensor, representing a collection of inputs. The + Schur complement that this kernel computes comes from a block matrix, + whose bottom-right corner is derived from + `base_kernel.matrix(fixed_inputs, fixed_inputs)`, and whose top-right + and bottom-left pieces are constructed by computing the base_kernel + at pairs of input locations together with these `fixed_inputs`. + `fixed_inputs` is allowed to be an empty collection (either `None` or + having a zero shape entry), in which case the kernel falls back to + the trivial application of `base_kernel` to inputs. See class-level + docstring for more details on the exact computation this does; + `fixed_inputs` correspond to the `Z` structure discussed there. + `fixed_inputs` (or each of its nested components) is assumed to have + shape `[b1, ..., bB, N, f1, ..., fF]` where the `b`'s are batch shape + entries, the `f`'s are feature_shape entries, and `N` is the number + of fixed inputs. + preconditioner_fn: A function that applies an invertible linear + transformation to its input, designed to increase the rate of + convergence by decreasing the condition number. The preconditioner_fn + should act like left application of an n by n linear operator, i.e. + preconditioner_fn(n x m) should have shape n x m. + diag_shift: A floating point scalar to be added to the diagonal of the + divisor_matrix. + """ + # TODO(srvasude): Support masking. + parameters = dict(locals()) + + if jax.tree_util.treedef_is_leaf( + jax.tree_util.tree_structure(base_kernel.feature_ndims)): + dtype = tfp.internal.dtype_util.common_dtype( + [base_kernel, fixed_inputs], + dtype_hint=tfp.internal.nest_util.broadcast_structure( + base_kernel.feature_ndims, jnp.float32)) + else: + # If the fixed inputs are not nested, we assume they are of the same + # float dtype as the remaining parameters. 
+ dtype = tfp.internal.dtype_util.common_dtype( + [base_kernel, + fixed_inputs, + diag_shift], jnp.float32) + + self._base_kernel = base_kernel + self._diag_shift = diag_shift + self._fixed_inputs = fixed_inputs + self._preconditioner_fn = preconditioner_fn + + super(SchurComplement, self).__init__( + base_kernel.feature_ndims, + dtype=dtype, + name='SchurComplement', + parameters=parameters) + + def _is_fixed_inputs_empty(self): + # If fixed_inputs are `None` or have size 0, we consider this empty and fall + # back to (cheaper) trivial behavior. + if self._fixed_inputs is None: + return True + num_fixed_inputs = jax.tree_util.tree_map( + lambda t, nd: t.shape[-(nd + 1)], + self._fixed_inputs, self._base_kernel.feature_ndims) + if all(n is not None and n == 0 for n in jax.tree_util.tree_leaves( + num_fixed_inputs)): + return True + return False + + def _apply(self, x1, x2, example_ndims): + # In the shape annotations below, + # + # - x1 has shape B1 + E1 + F (batch, example, feature), + # - x2 has shape B2 + E2 + F, + # - z refers to self.fixed_inputs, and has shape Bz + [ez] + F, ie its + # example ndims is exactly 1, + # - self.base_kernel has batch shape Bk, + # - bc(A, B, C) means "the result of broadcasting shapes A, B, and C". + fixed_inputs = self._fixed_inputs + + # Shape: bc(Bk, B1, B2) + bc(E1, E2) + k12 = self.base_kernel.apply(x1, x2, example_ndims) + if self._is_fixed_inputs_empty(): + return k12 + + # Shape: bc(Bk, B1, Bz) + E1 + [ez] + k1z = self.base_kernel.tensor(x1, fixed_inputs, + x1_example_ndims=example_ndims, + x2_example_ndims=1) + # Shape: bc(Bk, B2, Bz) + E2 + [ez] + k2z = self.base_kernel.tensor(x2, fixed_inputs, + x1_example_ndims=example_ndims, + x2_example_ndims=1) + k2z = jnp.reshape(k2z, [-1, k2z.shape[-1]]) + # Shape: bc(Bz, Bk) + [ez, ez] + div_mat = self._divisor_matrix(fixed_inputs=fixed_inputs) + + div_mat = util.pad_shape_with_ones(div_mat, example_ndims - 1, -3) + + kzzinv_kz2, _ = mbcg.modified_batched_conjugate_gradients( + lambda x: div_mat @ x, + jnp.transpose(k2z), + preconditioner_fn=self.preconditioner_fn, + max_iters=20) + kzzinv_kz2 = jnp.transpose(kzzinv_kz2) + k1z_kzzinv_kz2 = jnp.sum(k1z * kzzinv_kz2, axis=-1) + + return k12 - k1z_kzzinv_kz2 + + def _matrix(self, x1, x2): + k12 = self.base_kernel.matrix(x1, x2) + if self._is_fixed_inputs_empty(): + return k12 + fixed_inputs = self._fixed_inputs + + # Shape: bc(Bk, B1, Bz) + [e1] + [ez] + k1z = self.base_kernel.matrix(x1, fixed_inputs) + + # Shape: bc(Bk, B2, Bz) + [e2] + [ez] + k2z = self.base_kernel.matrix(x2, fixed_inputs) + + # Shape: bc(Bz, Bk) + [ez, ez] + div_mat = self._divisor_matrix(fixed_inputs=fixed_inputs) + + kzzinv_kz2, _ = mbcg.modified_batched_conjugate_gradients( + lambda x: div_mat @ x, + jnp.transpose(k2z), + preconditioner_fn=self.preconditioner_fn, + max_iters=20) + + k1z_kzzinv_kz2 = k1z @ kzzinv_kz2 + return k12 - k1z_kzzinv_kz2 + + @property + def fixed_inputs(self): + return self._fixed_inputs + + @property + def base_kernel(self): + return self._base_kernel + + @property + def diag_shift(self): + return self._diag_shift + + @property + def preconditioner_fn(self): + return self._preconditioner_fn + + @classmethod + def _parameter_properties(cls, dtype): + return dict( + base_kernel=parameter_properties.BatchedComponentProperties(), + fixed_inputs=parameter_properties.ParameterProperties( + event_ndims=lambda self: jax.tree_util.tree_map( # pylint: disable=g-long-lambda + lambda nd: nd + 1, self.base_kernel.feature_ndims)), + 
diag_shift=parameter_properties.ParameterProperties( + default_constraining_bijector_fn=( + lambda: tfp.bijectors.Softplus( # pylint: disable=g-long-lambda + low=tfp.internal.dtype_util.eps(dtype))))) + + def _divisor_matrix(self, fixed_inputs=None): + fixed_inputs = self._fixed_inputs if fixed_inputs is None else fixed_inputs + return _compute_divisor_matrix( + self._base_kernel, + diag_shift=self._diag_shift, + fixed_inputs=fixed_inputs) + + def divisor_matrix(self): + return self._divisor_matrix() diff --git a/tensorflow_probability/python/experimental/fastgp/schur_complement_test.py b/tensorflow_probability/python/experimental/fastgp/schur_complement_test.py new file mode 100644 index 0000000000..3312894a01 --- /dev/null +++ b/tensorflow_probability/python/experimental/fastgp/schur_complement_test.py @@ -0,0 +1,97 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for schur_complement.py.""" +from absl.testing import parameterized +import jax +from jax import config +import numpy as np +from tensorflow_probability.python.experimental.fastgp import preconditioners +from tensorflow_probability.python.experimental.fastgp import schur_complement +from tensorflow_probability.substrates import jax as tfp +from absl.testing import absltest + + +class _SchurComplementTest(parameterized.TestCase): + + @parameterized.parameters( + {'dims': 3}, + {'dims': 4}, + {'dims': 5}, + {'dims': 7}, + {'dims': 11}) + def testValuesAreCorrect(self, dims): + np.random.seed(42) + num_obs = 5 + num_x = 3 + num_y = 7 + + shape = [dims] + + base_kernel = tfp.math.psd_kernels.ExponentiatedQuadratic( + self.dtype(5.), self.dtype(1.), feature_ndims=1) + + fixed_inputs = np.random.uniform( + -1., 1., size=[num_obs] + shape).astype(self.dtype) + + expected_k = tfp.math.psd_kernels.SchurComplement( + base_kernel=base_kernel, fixed_inputs=fixed_inputs) + + # Can use a dummy matrix since this is ignored. 
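+    # (IdentityPreconditioner only uses the shape and dtype of its argument;
+    # its full_preconditioner().solve is the identity map, so it acts as a
+    # no-op preconditioner_fn here.)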
+ pc = preconditioners.IdentityPreconditioner( + np.ones([num_obs, num_obs])) + + actual_k = schur_complement.SchurComplement( + base_kernel=base_kernel, + preconditioner_fn=pc.full_preconditioner().solve, + fixed_inputs=fixed_inputs) + + for i in range(5): + x = jax.random.uniform( + jax.random.PRNGKey(i), + minval=-1, + maxval=1, + shape=[num_x] + shape).astype(self.dtype) + y1 = jax.random.uniform( + jax.random.PRNGKey( + 2 * i), minval=-1, maxval=1, shape=[num_x] + shape).astype( + self.dtype) + y2 = jax.random.uniform( + jax.random.PRNGKey( + 2 * i + 1), minval=-1, maxval=1, shape=[num_y] + shape).astype( + self.dtype) + np.testing.assert_allclose( + expected_k.apply(x, y1, example_ndims=1), + actual_k.apply(x, y1, example_ndims=1), + rtol=6e-4) + np.testing.assert_allclose( + expected_k.matrix(x, y2), + actual_k.matrix(x, y2), + rtol=3e-3) + + +class SchurComplementTestFloat32(_SchurComplementTest): + dtype = np.float32 + + +class SchurComplementTestFloat64(_SchurComplementTest): + dtype = np.float64 + + +del _SchurComplementTest + + +if __name__ == '__main__': + config.update('jax_enable_x64', True) + absltest.main()