diff --git a/tensorflow_probability/python/experimental/fastgp/BUILD b/tensorflow_probability/python/experimental/fastgp/BUILD
index f956e86750..f27a3bd7f2 100644
--- a/tensorflow_probability/python/experimental/fastgp/BUILD
+++ b/tensorflow_probability/python/experimental/fastgp/BUILD
@@ -28,6 +28,36 @@ package(
     ],
 )
 
+py_library(
+    name = "fastgp.jax",
+    srcs = ["__init__.py"],
+    deps = [
+        ":fast_gp",
+        ":fast_gprm",
+        ":fast_log_det",
+        ":fast_mtgp",
+        ":linalg",
+        ":linear_operator_sum",
+        ":mbcg",
+        ":partial_lanczos",
+        ":preconditioners",
+        ":schur_complement",
+        "//tensorflow_probability/python/internal:all_util",
+    ],
+)
+
+# Dummy libraries to satisfy the multi_substrate_py_library deps of
+# tfp/python/experimental:experimental.
+py_library(
+    name = "fastgp",
+    deps = [],
+)
+
+py_library(
+    name = "fastgp.numpy",
+    deps = [],
+)
+
 py_library(
     name = "mbcg",
     srcs = ["mbcg.py"],
@@ -55,7 +85,16 @@ py_library(
         ":mbcg",
         ":preconditioners",
         # jax dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/bijectors:softplus.jax",
+        "//tensorflow_probability/python/distributions:distribution.jax",
+        "//tensorflow_probability/python/distributions:gaussian_process_regression_model.jax",
+        "//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
+        "//tensorflow_probability/python/internal:dtype_util.jax",
+        "//tensorflow_probability/python/internal:parameter_properties.jax",
+        "//tensorflow_probability/python/internal:reparameterization",
+        "//tensorflow_probability/python/internal:tensor_util.jax",
+        "//tensorflow_probability/python/internal/backend/jax",
+        "//tensorflow_probability/python/mcmc:sample_halton_sequence.jax",
     ],
 )
 
@@ -83,7 +122,14 @@ py_library(
         ":mbcg",
         ":preconditioners",
         # jax dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/distributions:distribution.jax",
+        "//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
+        "//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel.jax",
+        "//tensorflow_probability/python/internal:dtype_util.jax",
+        "//tensorflow_probability/python/internal:prefer_static.jax",
+        "//tensorflow_probability/python/internal:reparameterization",
+        "//tensorflow_probability/python/internal:tensor_util.jax",
+        "//tensorflow_probability/python/internal/backend/jax",
     ],
 )
 
@@ -108,7 +154,11 @@ py_library(
         ":preconditioners",
         ":schur_complement",
         # jax dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/bijectors:softplus.jax",
+        "//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
+        "//tensorflow_probability/python/internal:dtype_util.jax",
+        "//tensorflow_probability/python/internal:nest_util.jax",
+        "//tensorflow_probability/python/internal:parameter_properties.jax",
     ],
 )
 
@@ -133,7 +183,7 @@ py_library(
         # jax dep,
         # jax:experimental_sparse dep,
         # jaxtyping dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/internal/backend/jax",
     ],
 )
 
@@ -157,7 +207,7 @@ py_library(
         ":mbcg",
         # jax dep,
         # scipy dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/internal/backend/jax",
     ],
 )
 
@@ -165,10 +215,12 @@ py_test(
     name = "partial_lanczos_test",
     srcs = ["partial_lanczos_test.py"],
     deps = [
+        ":mbcg",
         ":partial_lanczos",
         # absl/testing:absltest dep,
         # jax dep,
         # numpy dep,
+        "//tensorflow_probability/substrates:jax",
     ],
 )
 
@@ -183,7 +235,6 @@ py_library(
         # jax dep,
         # jaxtyping dep,
         # numpy dep,
         # scipy dep,
-        "//tensorflow_probability/substrates:jax",
     ],
 )
 
@@ -206,7 +257,6 @@ py_library(
     name = "linear_operator_sum",
     srcs = ["linear_operator_sum.py"],
     deps = [
-        "//tensorflow_probability/substrates:jax",
     ],
 )
 
@@ -219,7 +269,8 @@ py_library(
         # jax dep,
         # jax:experimental_sparse dep,
         # jaxtyping dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/internal/backend/jax",
+        "//tensorflow_probability/python/math:linalg.jax",
     ],
 )
 
@@ -232,7 +283,7 @@ py_test(
         # absl/testing:absltest dep,
         # absl/testing:parameterized dep,
         # jax dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/internal/backend/jax",
     ],
 )
 
@@ -242,7 +293,10 @@ py_library(
     deps = [
         ":preconditioners",
         # jax dep,
-        "//tensorflow_probability/substrates:jax",
+        "//tensorflow_probability/python/bijectors:softplus.jax",
+        "//tensorflow_probability/python/internal:distribution_util.jax",
+        "//tensorflow_probability/python/internal:dtype_util.jax",
+        "//tensorflow_probability/python/internal:nest_util.jax",
     ],
 )
diff --git a/tensorflow_probability/python/experimental/fastgp/__init__.py b/tensorflow_probability/python/experimental/fastgp/__init__.py
new file mode 100644
index 0000000000..828ff72c70
--- /dev/null
+++ b/tensorflow_probability/python/experimental/fastgp/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2024 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Package for training Gaussian Processes in time less than O(n^3)."""
+
+from tensorflow_probability.python.experimental.fastgp import fast_gp
+from tensorflow_probability.python.experimental.fastgp import fast_gprm
+from tensorflow_probability.python.experimental.fastgp import fast_log_det
+from tensorflow_probability.python.experimental.fastgp import fast_mtgp
+from tensorflow_probability.python.experimental.fastgp import linalg
+from tensorflow_probability.python.experimental.fastgp import linear_operator_sum
+from tensorflow_probability.python.experimental.fastgp import mbcg
+from tensorflow_probability.python.experimental.fastgp import partial_lanczos
+from tensorflow_probability.python.experimental.fastgp import preconditioners
+from tensorflow_probability.python.experimental.fastgp import schur_complement
+from tensorflow_probability.python.internal import all_util
+
+_allowed_symbols = [
+    'fast_log_det',
+    'fast_gp',
+    'fast_gprm',
+    'fast_mtgp',
+    'linalg',
+    'linear_operator_sum',
+    'mbcg',
+    'partial_lanczos',
+    'preconditioners',
+    'schur_complement',
+]
+
+all_util.remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp.py b/tensorflow_probability/python/experimental/fastgp/fast_gp.py
index 79ea875948..b7c299af2a 100644
--- a/tensorflow_probability/python/experimental/fastgp/fast_gp.py
+++ b/tensorflow_probability/python/experimental/fastgp/fast_gp.py
@@ -25,11 +25,16 @@
 from tensorflow_probability.python.experimental.fastgp import fast_log_det
 from tensorflow_probability.python.experimental.fastgp import mbcg
 from tensorflow_probability.python.experimental.fastgp import preconditioners
-from tensorflow_probability.substrates import jax as tfp
+from tensorflow_probability.python.internal import reparameterization
+from tensorflow_probability.python.internal.backend import jax as tf2jax
+from tensorflow_probability.substrates.jax.bijectors import softplus
+from tensorflow_probability.substrates.jax.distributions import distribution
+from tensorflow_probability.substrates.jax.distributions import gaussian_process_regression_model
+from tensorflow_probability.substrates.jax.distributions.internal import stochastic_process_util
+from tensorflow_probability.substrates.jax.internal import dtype_util
+from tensorflow_probability.substrates.jax.internal import parameter_properties
+from tensorflow_probability.substrates.jax.internal import tensor_util
 
-tfb = tfp.bijectors
-tfd = tfp.distributions
-jtf = tfp.tf2jax
 Array = jnp.ndarray
 
 LOG_TWO_PI = 1.8378770664093453
@@ -76,7 +81,7 @@ class GaussianProcessConfig:
   log_det_iters: int = 20
 
 
-class GaussianProcess(tfp.distributions.Distribution):
+class GaussianProcess(distribution.Distribution):
   """Fast, JAX-only implementation of a GP distribution class.
 
   See tfd.distributions.GaussianProcess for a description and parameter
@@ -131,26 +136,28 @@ def __init__(
         jax.tree_util.tree_structure(kernel.feature_ndims)):
       # If the index points are not nested, we assume they are of the same
      # float dtype as the GP.
- dtype = tfp.internal.dtype_util.common_dtype( - {'index_points': index_points, - 'observation_noise_variance': observation_noise_variance, - 'jitter': jitter}, - jnp.float32) + dtype = dtype_util.common_dtype( + { + 'index_points': index_points, + 'observation_noise_variance': observation_noise_variance, + 'jitter': jitter, + }, + jnp.float32, + ) else: - dtype = tfp.internal.dtype_util.common_dtype( - {'observation_noise_variance': observation_noise_variance, - 'jitter': jitter}, - jnp.float32) + dtype = dtype_util.common_dtype( + { + 'observation_noise_variance': observation_noise_variance, + 'jitter': jitter, + }, + jnp.float32, + ) self._kernel = kernel self._index_points = index_points - self._mean_fn = tfd.internal.stochastic_process_util.maybe_create_mean_fn( - mean_fn, dtype - ) - self._observation_noise_variance = ( - tfp.internal.tensor_util.convert_nonref_to_tensor( - observation_noise_variance - ) + self._mean_fn = stochastic_process_util.maybe_create_mean_fn(mean_fn, dtype) + self._observation_noise_variance = tensor_util.convert_nonref_to_tensor( + observation_noise_variance ) self._jitter = jitter self._config = config @@ -162,11 +169,12 @@ def __init__( super(GaussianProcess, self).__init__( dtype=dtype, - reparameterization_type=tfd.FULLY_REPARAMETERIZED, + reparameterization_type=reparameterization.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, - name='GaussianProcess') + name='GaussianProcess', + ) @property def kernel(self): @@ -191,24 +199,29 @@ def jitter(self): @classmethod def _parameter_properties(cls, dtype, num_classes=None): return dict( - index_points=tfp.util.ParameterProperties( + index_points=parameter_properties.ParameterProperties( event_ndims=lambda self: jax.tree_util.tree_map( # pylint: disable=g-long-lambda - lambda nd: nd + 1, self.kernel.feature_ndims), - shape_fn=tfp.internal.parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + lambda nd: nd + 1, self.kernel.feature_ndims + ), + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, ), - kernel=tfp.util.BatchedComponentProperties(), + kernel=parameter_properties.BatchedComponentProperties(), observation_noise_variance=( - tfp.util.ParameterProperties( + parameter_properties.ParameterProperties( event_ndims=0, shape_fn=lambda sample_shape: sample_shape[:-1], default_constraining_bijector_fn=( - lambda: tfb.Softplus( # pylint: disable=g-long-lambda - low=tfp.internal.dtype_util.eps(dtype)))))) + lambda: softplus.Softplus( # pylint: disable=g-long-lambda + low=dtype_util.eps(dtype) + ) + ), + ) + ), + ) @property def event_shape(self): - return tfd.internal.stochastic_process_util.event_shape( - self._kernel, self.index_points) + return stochastic_process_util.event_shape(self._kernel, self.index_points) def _mean(self): mean = self._mean_fn(self._index_points) @@ -217,12 +230,12 @@ def _mean(self): def _covariance(self): index_points = self._index_points - _, covariance = ( - tfd.internal.stochastic_process_util.get_loc_and_kernel_matrix( - kernel=self.kernel, - mean_fn=self._mean_fn, - observation_noise_variance=self.observation_noise_variance, - index_points=index_points)) + _, covariance = stochastic_process_util.get_loc_and_kernel_matrix( + kernel=self.kernel, + mean_fn=self._mean_fn, + observation_noise_variance=self.observation_noise_variance, + index_points=index_points, + ) return covariance def _variance(self): @@ -236,15 +249,13 @@ def _log_det(self, key, is_missing=None): # TODO(thomaswc): Considering caching loc and covariance for 
a given # is_missing. - loc, covariance = ( - tfd.internal.stochastic_process_util.get_loc_and_kernel_matrix( - kernel=self._kernel, - mean_fn=self._mean_fn, - observation_noise_variance=self.dtype(0.), - index_points=self._index_points, - is_missing=is_missing, - mask_loc=False, - ) + loc, covariance = stochastic_process_util.get_loc_and_kernel_matrix( + kernel=self._kernel, + mean_fn=self._mean_fn, + observation_noise_variance=self.dtype(0.0), + index_points=self._index_points, + is_missing=is_missing, + mask_loc=False, ) is_scaling_preconditioner = self._config.preconditioner.endswith('scaling') @@ -340,7 +351,7 @@ def posterior_predictive( **kwargs ): # TODO(thomaswc): Speed this up, if possible. - return tfd.GaussianProcessRegressionModel.precompute_regression_model( + return gaussian_process_regression_model.GaussianProcessRegressionModel.precompute_regression_model( kernel=self._kernel, observation_index_points=self._index_points, observations=observations, @@ -350,7 +361,7 @@ def posterior_predictive( mean_fn=self._mean_fn, cholesky_fn=None, jitter=self._jitter, - **kwargs + **kwargs, ) @@ -359,18 +370,18 @@ def posterior_predictive( @functools.partial(jax.custom_jvp, nondiff_argnums=(3,)) def yt_inv_y( - kernel: jtf.linalg.LinearOperator, - preconditioner: jtf.linalg.LinearOperator, + kernel: tf2jax.linalg.LinearOperator, + preconditioner: tf2jax.linalg.LinearOperator, y: Array, max_iters: int = 20, ) -> Array: """Compute y^t (kernel)^(-1) y. Args: - kernel: A matrix or jtf.LinearOperator representing a linear map from R^n to - itself. - preconditioner: An operator on R^n that when applied before kernel, - reduces the condition number of the system. + kernel: A matrix or linalg.LinearOperator representing a linear map from R^n + to itself. + preconditioner: An operator on R^n that when applied before kernel, reduces + the condition number of the system. y: A matrix of shape (n, m). max_iters: The maximum number of iterations to perform the modified batched conjugate gradients algorithm for. 
@@ -410,7 +421,7 @@ def multiplier(B): primal_out = jnp.einsum('ij,ij->j', y, inv_y) tangent_out = 2.0 * jnp.einsum('ik,ik->k', inv_y, dy) - if isinstance(dkernel, jtf.linalg.LinearOperator): + if isinstance(dkernel, tf2jax.linalg.LinearOperator): tangent_out = tangent_out - jnp.einsum('ik,ik->k', inv_y, dkernel @ inv_y) else: tangent_out = tangent_out - jnp.einsum('ik,ij,jk->k', inv_y, dkernel, inv_y) diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gprm.py b/tensorflow_probability/python/experimental/fastgp/fast_gprm.py index 386089d30e..f498e5e851 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_gprm.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_gprm.py @@ -21,12 +21,11 @@ from tensorflow_probability.python.experimental.fastgp import mbcg from tensorflow_probability.python.experimental.fastgp import preconditioners from tensorflow_probability.python.experimental.fastgp import schur_complement -from tensorflow_probability.substrates import jax as tfp - -parameter_properties = tfp.internal.parameter_properties -tfd = tfp.distributions -jtf = tfp.tf2jax - +from tensorflow_probability.substrates.jax.bijectors import softplus +from tensorflow_probability.substrates.jax.distributions.internal import stochastic_process_util +from tensorflow_probability.substrates.jax.internal import dtype_util +from tensorflow_probability.substrates.jax.internal import nest_util +from tensorflow_probability.substrates.jax.internal import parameter_properties __all__ = [ 'GaussianProcessRegressionModel', @@ -123,23 +122,25 @@ def __init__( # TODO(srvasude): Add support for masking observations. In addition, cache # the observation matrix so that it isn't recomputed every iteration. parameters = dict(locals()) - input_dtype = tfp.internal.dtype_util.common_dtype( + input_dtype = dtype_util.common_dtype( dict( kernel=kernel, index_points=index_points, observation_index_points=observation_index_points, ), - dtype_hint=tfp.internal.nest_util.broadcast_structure( - kernel.feature_ndims, np.float32)) + dtype_hint=nest_util.broadcast_structure( + kernel.feature_ndims, np.float32 + ), + ) # If the input dtype is non-nested float, we infer a single dtype for the # input and the float parameters, which is also the dtype of the GP's # samples, log_prob, etc. If the input dtype is nested (or not float), we # do not use it to infer the GP's float dtype. - if (not jax.tree_util.treedef_is_leaf( - jax.tree_util.tree_structure(input_dtype)) and - tfp.internal.dtype_util.is_floating(input_dtype)): - dtype = tfp.internal.dtype_util.common_dtype( + if not jax.tree_util.treedef_is_leaf( + jax.tree_util.tree_structure(input_dtype) + ) and dtype_util.is_floating(input_dtype): + dtype = dtype_util.common_dtype( dict( kernel=kernel, index_points=index_points, @@ -153,13 +154,15 @@ def __init__( ) input_dtype = dtype else: - dtype = tfp.internal.dtype_util.common_dtype( + dtype = dtype_util.common_dtype( dict( observations=observations, observation_noise_variance=observation_noise_variance, predictive_noise_variance=predictive_noise_variance, jitter=jitter, - ), dtype_hint=np.float32) + ), + dtype_hint=np.float32, + ) if predictive_noise_variance is None: predictive_noise_variance = observation_noise_variance @@ -170,8 +173,7 @@ def __init__( observations, observation_index_points)) # Default to a constant zero function, borrowing the dtype from # index_points to ensure consistency. 
- mean_fn = tfd.internal.stochastic_process_util.maybe_create_mean_fn( - mean_fn, dtype) + mean_fn = stochastic_process_util.maybe_create_mean_fn(mean_fn, dtype) self._observation_index_points = observation_index_points self._observations = observations @@ -231,14 +233,16 @@ def conditional_mean_fn(x): # Special logic for mean_fn only; SchurComplement already handles the # case of empty observations (ie, falls back to base_kernel). - if not tfd.internal.stochastic_process_util.is_empty_observation_data( + if not stochastic_process_util.is_empty_observation_data( feature_ndims=kernel.feature_ndims, observation_index_points=observation_index_points, - observations=observations): - tfd.internal.stochastic_process_util.validate_observation_data( + observations=observations, + ): + stochastic_process_util.validate_observation_data( kernel=kernel, observation_index_points=observation_index_points, - observations=observations) + observations=observations, + ) super(GaussianProcessRegressionModel, self).__init__( index_points=index_points, @@ -272,31 +276,39 @@ def _event_ndims_fn(self): return jax.tree_util.treep_map( lambda nd: nd + 1, self.kernel.feature_ndims) return dict( - index_points=tfp.util.ParameterProperties( + index_points=parameter_properties.ParameterProperties( event_ndims=_event_ndims_fn, shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, ), - observations=tfp.util.ParameterProperties( + observations=parameter_properties.ParameterProperties( event_ndims=1, - shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED), - observation_index_points=tfp.util.ParameterProperties( + shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, + ), + observation_index_points=parameter_properties.ParameterProperties( event_ndims=_event_ndims_fn, shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, ), - observations_is_missing=tfp.util.ParameterProperties( + observations_is_missing=parameter_properties.ParameterProperties( event_ndims=1, shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, ), - kernel=tfp.util.BatchedComponentProperties(), - observation_noise_variance=tfp.util.ParameterProperties( + kernel=parameter_properties.BatchedComponentProperties(), + observation_noise_variance=parameter_properties.ParameterProperties( event_ndims=0, shape_fn=lambda sample_shape: sample_shape[:-1], default_constraining_bijector_fn=( - lambda: tfp.bijectors.Softplus( # pylint:disable=g-long-lambda - low=tfp.internal.dtype_util.eps(dtype)))), - predictive_noise_variance=tfp.util.ParameterProperties( + lambda: softplus.Softplus( # pylint:disable=g-long-lambda + low=dtype_util.eps(dtype) + ) + ), + ), + predictive_noise_variance=parameter_properties.ParameterProperties( event_ndims=0, shape_fn=lambda sample_shape: sample_shape[:-1], default_constraining_bijector_fn=( - lambda: tfp.bijectors.Softplus( # pylint:disable=g-long-lambda - low=tfp.internal.dtype_util.eps(dtype))))) + lambda: softplus.Softplus( # pylint:disable=g-long-lambda + low=dtype_util.eps(dtype) + ) + ), + ), + ) diff --git a/tensorflow_probability/python/experimental/fastgp/fast_log_det.py b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py index 92598d513f..ae4fc146b8 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_log_det.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py @@ -27,9 +27,9 @@ from tensorflow_probability.python.experimental.fastgp import mbcg from tensorflow_probability.python.experimental.fastgp import partial_lanczos from 
tensorflow_probability.python.experimental.fastgp import preconditioners -from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.python.internal.backend import jax as tf2jax +from tensorflow_probability.substrates.jax.mcmc import sample_halton_sequence_lib -jtf = tfp.tf2jax Array = jnp.ndarray # pylint: disable=invalid-name @@ -75,11 +75,9 @@ def make_probe_vectors( return q * norm if probe_vector_type == ProbeVectorType.NORMAL_QMC: - uniforms = tfp.mcmc.sample_halton_sequence( - dim=n, - num_results=num_probe_vectors, - dtype=dtype, - seed=key) + uniforms = sample_halton_sequence_lib.sample_halton_sequence( + dim=n, num_results=num_probe_vectors, dtype=dtype, seed=key + ) return jnp.transpose(jax.scipy.special.ndtri(uniforms)) raise ValueError( @@ -292,7 +290,7 @@ def _log_det_rational_approx_with_hutchinson( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r1( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -312,7 +310,7 @@ def _r1( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r2( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -332,7 +330,7 @@ def _r2( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r3( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -352,7 +350,7 @@ def _r3( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r4( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -372,7 +370,7 @@ def _r4( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r5( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -392,7 +390,7 @@ def _r5( @functools.partial(jax.custom_jvp, nondiff_argnums=(4,)) def _r6( - unused_M: jtf.linalg.LinearOperator, + unused_M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, key: jax.Array, @@ -448,7 +446,7 @@ def _r6_jvp(num_iters, primals, tangents): @jax.named_call def r1( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -468,7 +466,7 @@ def r1( @jax.named_call def r2( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -488,7 +486,7 @@ def r2( @jax.named_call def r3( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -508,7 +506,7 @@ def r3( @jax.named_call def r4( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -528,7 +526,7 @@ def r4( @jax.named_call def r5( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -548,7 +546,7 @@ def r5( @jax.named_call def r6( - M: 
jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -593,7 +591,7 @@ def batch_log00(ts: mbcg.SymmetricTridiagonalMatrix) -> Array: jax.jit, static_argnames=['probe_vectors_are_rademacher', 'num_iters'] ) def _stochastic_lanczos_quadrature_log_det( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, probe_vectors: Array, unused_key, @@ -634,7 +632,7 @@ def _stochastic_lanczos_quadrature_log_det( @jax.named_call def stochastic_lanczos_quadrature_log_det( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, preconditioner: preconditioners.Preconditioner, key: jax.Array, num_probe_vectors: int = 25, @@ -701,7 +699,7 @@ def get_log_det_algorithm(alg_name: str): def log_det_taylor_series_with_hutchinson( - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, num_probe_vectors: int, key: jax.Array, num_taylor_series_iterations: int = 10, diff --git a/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py b/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py index 2351621ca3..56c6149cb9 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_mtgp.py @@ -20,13 +20,16 @@ from tensorflow_probability.python.experimental.fastgp import fast_log_det from tensorflow_probability.python.experimental.fastgp import linear_operator_sum from tensorflow_probability.python.experimental.fastgp import preconditioners -from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.python.internal import reparameterization +from tensorflow_probability.python.internal.backend import jax as tf2jax +from tensorflow_probability.substrates.jax.distributions import distribution +from tensorflow_probability.substrates.jax.distributions.internal import stochastic_process_util +from tensorflow_probability.substrates.jax.experimental.psd_kernels import multitask_kernel +from tensorflow_probability.substrates.jax.internal import dtype_util +from tensorflow_probability.substrates.jax.internal import prefer_static as ps +from tensorflow_probability.substrates.jax.internal import tensor_util + -ps = tfp.internal.prefer_static -tfd = tfp.distributions -tfed = tfp.experimental.distributions -tfek = tfp.experimental.psd_kernels -jtf = tfp.tf2jax Array = jnp.ndarray @@ -43,11 +46,11 @@ def _unvec(x, matrix_shape): return jnp.reshape(x, ps.concat([ps.shape(x)[:-1], matrix_shape], axis=0)) -class MultiTaskGaussianProcess(tfd.AutoCompositeTensorDistribution): +class MultiTaskGaussianProcess(distribution.AutoCompositeTensorDistribution): """Fast, JAX-only implementation of a MTGP distribution class. - See tfed.distributions.MultiTaskGaussianProcess for a description and - parameter documentation. + See tfp.experimental.distributions.MultiTaskGaussianProcess for a description + and parameter documentation. """ def __init__( @@ -87,24 +90,26 @@ def __init__( jax.tree_util.tree_structure(kernel.feature_ndims)): # If the index points are not nested, we assume they are of the same # float dtype as the GP. 
- dtype = tfp.internal.dtype_util.common_dtype( - {'index_points': index_points, - 'observation_noise_variance': observation_noise_variance}, - jnp.float32) + dtype = dtype_util.common_dtype( + { + 'index_points': index_points, + 'observation_noise_variance': observation_noise_variance, + }, + jnp.float32, + ) else: - dtype = tfp.internal.dtype_util.common_dtype( + dtype = dtype_util.common_dtype( {'observation_noise_variance': observation_noise_variance}, - jnp.float32) + jnp.float32, + ) self._kernel = kernel self._index_points = index_points - self._mean_fn = ( - tfd.internal.stochastic_process_util.maybe_create_multitask_mean_fn( - mean_fn, kernel, dtype)) - self._observation_noise_variance = ( - tfp.internal.tensor_util.convert_nonref_to_tensor( - observation_noise_variance - ) + self._mean_fn = stochastic_process_util.maybe_create_multitask_mean_fn( + mean_fn, kernel, dtype + ) + self._observation_noise_variance = tensor_util.convert_nonref_to_tensor( + observation_noise_variance ) self._config = config self._probe_vector_type = fast_log_det.ProbeVectorType[ @@ -114,11 +119,12 @@ def __init__( super(MultiTaskGaussianProcess, self).__init__( dtype=dtype, - reparameterization_type=tfd.FULLY_REPARAMETERIZED, + reparameterization_type=reparameterization.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, - name='MultiTaskGaussianProcess') + name='MultiTaskGaussianProcess', + ) @property def kernel(self): @@ -138,8 +144,9 @@ def observation_noise_variance(self): @property def event_shape(self): - return tfd.internal.stochastic_process_util.multitask_event_shape( - self._kernel, self.index_points) + return stochastic_process_util.multitask_event_shape( + self._kernel, self.index_points + ) def _mean(self): loc = self._mean_fn(self._index_points) @@ -151,7 +158,7 @@ def _variance(self): index_points, index_points) observation_noise_variance = self.observation_noise_variance # We can add the observation noise to each block. 
- if isinstance(self.kernel, tfek.Independent): + if isinstance(self.kernel, multitask_kernel.Independent): single_task_variance = kernel_matrix.operators[0].diag_part() if observation_noise_variance is not None: single_task_variance = ( @@ -217,11 +224,13 @@ def get_preconditioner(cov): if is_scaling_preconditioner: preconditioner = get_preconditioner(covariance) - covariance = linear_operator_sum.LinearOperatorSum( - [covariance, - jtf.linalg.LinearOperatorScaledIdentity( - num_rows=covariance.range_dimension, - multiplier=self._observation_noise_variance)]) + covariance = linear_operator_sum.LinearOperatorSum([ + covariance, + tf2jax.linalg.LinearOperatorScaledIdentity( + num_rows=covariance.range_dimension, + multiplier=self._observation_noise_variance, + ), + ]) if not is_scaling_preconditioner: preconditioner = get_preconditioner(covariance) diff --git a/tensorflow_probability/python/experimental/fastgp/linalg.py b/tensorflow_probability/python/experimental/fastgp/linalg.py index 59b3a30f45..4c28864bf8 100644 --- a/tensorflow_probability/python/experimental/fastgp/linalg.py +++ b/tensorflow_probability/python/experimental/fastgp/linalg.py @@ -21,22 +21,21 @@ from jaxtyping import Float import numpy as np from tensorflow_probability.python.experimental.fastgp import partial_lanczos -from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.python.internal.backend import jax as tf2jax -jtf = tfp.tf2jax Array = jnp.ndarray # pylint: disable=invalid-name def _matvec(M, x) -> jax.Array: - if isinstance(M, jtf.linalg.LinearOperator): + if isinstance(M, tf2jax.linalg.LinearOperator): return M.matvec(x) return M @ x def largest_eigenvector( - M: jtf.linalg.LinearOperator, key: jax.Array, num_iters: int = 10 + M: tf2jax.linalg.LinearOperator, key: jax.Array, num_iters: int = 10 ) -> tuple[Float, Array]: """Returns the largest (eigenvalue, eigenvector) of M.""" n = M.shape[-1] @@ -52,10 +51,11 @@ def largest_eigenvector( def make_randomized_truncated_svd( key: jax.Array, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, rank: int = 20, oversampling: int = 10, - num_iters: int = 4) -> tuple[Float, Array]: + num_iters: int = 4, +) -> tuple[Float, Array]: """Returns approximate SVD for symmetric `M`.""" # This is based on: # N. Halko, P.G. Martinsson, J. A. 
Tropp @@ -91,7 +91,8 @@ def make_randomized_truncated_svd( def make_partial_lanczos( - key: jax.Array, M: jtf.linalg.LinearOperator, rank: int) -> Array: + key: jax.Array, M: tf2jax.linalg.LinearOperator, rank: int +) -> Array: """Return low rank approximation to M based on the partial Lancozs alg.""" n = M.shape[-1] key1, key2 = jax.random.split(key) @@ -118,7 +119,8 @@ def make_partial_lanczos( def make_truncated_svd( - key, M: jtf.linalg.LinearOperator, rank: int, num_iters: int) -> Array: + key, M: tf2jax.linalg.LinearOperator, rank: int, num_iters: int +) -> Array: """Return low rank approximation to M based on the partial SVD alg.""" n = M.shape[-1] if 5 * rank >= n: @@ -138,7 +140,8 @@ def make_truncated_svd( @functools.partial(jax.jit, static_argnums=1) def make_partial_pivoted_cholesky( - M: jtf.linalg.LinearOperator, rank: int) -> Array: + M: tf2jax.linalg.LinearOperator, rank: int +) -> Array: """Return low rank approximation to M based on partial pivoted Cholesky.""" n = M.shape[-1] diff --git a/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py b/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py index 7c6dd52316..f4fc16525c 100644 --- a/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py +++ b/tensorflow_probability/python/experimental/fastgp/linear_operator_sum.py @@ -15,13 +15,11 @@ """Expresses a sum of operators.""" import jax -from tensorflow_probability.substrates import jax as tfp - -jtf = tfp.tf2jax +from tensorflow_probability.python.internal.backend import jax as tf2jax @jax.tree_util.register_pytree_node_class -class LinearOperatorSum(jtf.linalg.LinearOperator): +class LinearOperatorSum(tf2jax.linalg.LinearOperator): """Encapsulates a sum of linear operators.""" def __init__(self, diff --git a/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py b/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py index aa61e7385c..5f0c32c374 100644 --- a/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py +++ b/tensorflow_probability/python/experimental/fastgp/partial_lanczos.py @@ -20,9 +20,9 @@ import jax.numpy as jnp import scipy from tensorflow_probability.python.experimental.fastgp import mbcg -from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.python.internal.backend import jax as tf2jax + -jtf = tfp.tf2jax Array = jnp.ndarray # pylint: disable=invalid-name @@ -155,9 +155,8 @@ def scan_func(loop_info, unused_x): def make_lanczos_preconditioner( - kernel: jtf.linalg.LinearOperator, - key: jax.Array, - num_iters: int = 20): + kernel: tf2jax.linalg.LinearOperator, key: jax.Array, num_iters: int = 20 +): """Return a preconditioner as a linear operator.""" n = kernel.shape[-1] key1, key2 = jax.random.split(key) @@ -166,8 +165,9 @@ def make_lanczos_preconditioner( Q, T = partial_lanczos(lambda x: kernel @ x, v, key2, num_iters) # Now diagonalize T as Q^t D Q - # TODO(thomaswc): Replace this scipy call with jnp.linalg.eigh so that - # it can be jit-ed. + # TODO(thomaswc): Once jax.scipy.linalg.eigh_tridiagonal supports + # eigenvectors (https://github.com/google/jax/issues/14019), replace + # this with that so that it can be jit-ed. 
evalues, evectors = scipy.linalg.eigh_tridiagonal( T.diag[0, :], T.off_diag[0, :]) sqrt_evalues = jnp.sqrt(evalues) @@ -176,17 +176,17 @@ def make_lanczos_preconditioner( # diag(F^t F)_i = sum_k (F^t)_{i, k} F_{k, i} = sum_k F_{k, i}^2 diag_Ft_F = jnp.sum(F * F, axis=0) - residual_diag = jtf.linalg.diag_part(kernel) - diag_Ft_F + residual_diag = tf2jax.linalg.diag_part(kernel) - diag_Ft_F eps = jnp.finfo(kernel.dtype).eps # TODO(srvasude): Modify this when residual_diag is near zero. This means that # we captured the diagonal appropriately, and modifying with a shift of eps # can alter the preconditioner greatly. - diag_linop = tfp.tf2jax.linalg.LinearOperatorDiag( - jnp.maximum(residual_diag, 0.0) + 10. * eps, is_positive_definite=True + diag_linop = tf2jax.linalg.LinearOperatorDiag( + jnp.maximum(residual_diag, 0.0) + 10.0 * eps, is_positive_definite=True ) - return tfp.tf2jax.linalg.LinearOperatorLowRankUpdate( + return tf2jax.linalg.LinearOperatorLowRankUpdate( diag_linop, jnp.transpose(F), is_positive_definite=True ) diff --git a/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py b/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py index 31f58f5d37..89d4b8204c 100644 --- a/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py +++ b/tensorflow_probability/python/experimental/fastgp/partial_lanczos_test.py @@ -26,15 +26,22 @@ class _PartialLanczosTest(absltest.TestCase): + def test_gram_schmidt(self): w = jnp.ones((5, 1), dtype=self.dtype) v = partial_lanczos.gram_schmidt( - jnp.array([[[1.0, 0, 0, 0, 0], - [0, 1.0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]], dtype=self.dtype), - w) + jnp.array( + [[ + [1.0, 0, 0, 0, 0], + [0, 1.0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ]], + dtype=self.dtype, + ), + w, + ) self.assertEqual((5, 1), v.shape) self.assertEqual(0.0, v[0][0]) self.assertEqual(0.0, v[1][0]) @@ -44,18 +51,21 @@ def test_partial_lanczos_identity(self): A = jnp.identity(10).astype(self.dtype) v = jnp.ones((10, 1)).astype(self.dtype) Q, T = partial_lanczos.partial_lanczos( - lambda x: A @ x, v, jax.random.PRNGKey(2), 10) + lambda x: A @ x, v, jax.random.PRNGKey(2), 10 + ) np.testing.assert_allclose(jnp.identity(10), Q[0] @ Q[0].T, atol=1e-6) np.testing.assert_allclose( - mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), - 1.0, rtol=1e-5) + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), 1.0, rtol=1e-5 + ) def test_diagonal_matrix_heavily_imbalanced(self): - A = jnp.diag(jnp.array([ - 1e-3, 1., 2., 3., 4., 10000.], dtype=self.dtype)) + A = jnp.diag( + jnp.array([1e-3, 1.0, 2.0, 3.0, 4.0, 10000.0], dtype=self.dtype) + ) v = jnp.ones((6, 1)).astype(self.dtype) Q, T = partial_lanczos.partial_lanczos( - lambda x: A @ x, v, jax.random.PRNGKey(9), 6) + lambda x: A @ x, v, jax.random.PRNGKey(9), 6 + ) atol = 1e-6 det_rtol = 1e-6 if self.dtype == np.float32: @@ -64,22 +74,26 @@ def test_diagonal_matrix_heavily_imbalanced(self): np.testing.assert_allclose(jnp.identity(6), Q[0] @ Q[0].T, atol=atol) np.testing.assert_allclose( mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), - 240., rtol=det_rtol) + 240.0, + rtol=det_rtol, + ) def test_partial_lanczos_full_lanczos(self): A = jnp.array([[1.0, 1.0], [1.0, 4.0]], dtype=self.dtype) v = jnp.array([[-1.0], [1.0]], dtype=self.dtype) Q, T = partial_lanczos.partial_lanczos( - lambda x: A @ x, v, jax.random.PRNGKey(3), 2) + lambda x: A @ x, v, jax.random.PRNGKey(3), 2 + ) np.testing.assert_allclose(jnp.identity(2), 
Q[0] @ Q[0].T, atol=1e-6) np.testing.assert_allclose( - mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), - 3.0, rtol=1e-5) + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), 3.0, rtol=1e-5 + ) def test_partial_lanczos_with_jit(self): def partial_lanczos_pure_tensor(A, v): return partial_lanczos.partial_lanczos( - lambda x: A @ x, v, jax.random.PRNGKey(4), 2) + lambda x: A @ x, v, jax.random.PRNGKey(4), 2 + ) partial_lanczos_jit = jax.jit(partial_lanczos_pure_tensor) @@ -88,35 +102,37 @@ def partial_lanczos_pure_tensor(A, v): Q, T = partial_lanczos_jit(A, v) np.testing.assert_allclose(jnp.identity(2), Q[0] @ Q[0].T, atol=1e-6) np.testing.assert_allclose( - mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), - 3.0, rtol=1e-5) + mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]), 3.0, rtol=1e-5 + ) def test_partial_lanczos_with_batching(self): v = jnp.zeros((10, 3), dtype=self.dtype) v = v.at[:, 0].set(jnp.ones(10)) v = v.at[:, 1].set(jnp.array([1, -1, 1, -1, 1, -1, 1, -1, 1, -1])) - v = v.at[:, 2].set(jnp.array( - [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])) + v = v.at[:, 2].set( + jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + ) Q, T = partial_lanczos.partial_lanczos( - lambda x: x, v, jax.random.PRNGKey(5), 10) + lambda x: x, v, jax.random.PRNGKey(5), 10 + ) self.assertEqual(Q.shape, (3, 10, 10)) self.assertEqual(T.diag.shape, (3, 10)) - np.testing.assert_allclose( - Q[0, 0, :], jnp.ones(10) / jnp.sqrt(10.0)) - np.testing.assert_allclose( - Q[1, 0, :], v[:, 1] / jnp.sqrt(10.0)) + np.testing.assert_allclose(Q[0, 0, :], jnp.ones(10) / jnp.sqrt(10.0)) + np.testing.assert_allclose(Q[1, 0, :], v[:, 1] / jnp.sqrt(10.0)) def test_make_lanczos_preconditioner(self): kernel = jnp.identity(10).astype(self.dtype) preconditioner = partial_lanczos.make_lanczos_preconditioner( - kernel, jax.random.PRNGKey(5)) + kernel, jax.random.PRNGKey(5) + ) log_det = preconditioner.log_abs_determinant() self.assertAlmostEqual(0.0, log_det, places=4) out = preconditioner.solve(jnp.identity(10)) np.testing.assert_allclose(out, jnp.identity(10), atol=9e-2) kernel = jnp.identity(100).astype(self.dtype) preconditioner = partial_lanczos.make_lanczos_preconditioner( - kernel, jax.random.PRNGKey(6)) + kernel, jax.random.PRNGKey(6) + ) # TODO(thomaswc): Investigate ways to improve the numerical stability # here so that log_det is closer to zero than this. 
log_det = preconditioner.log_abs_determinant() @@ -125,16 +141,18 @@ def test_make_lanczos_preconditioner(self): np.testing.assert_allclose(out, jnp.identity(100), atol=0.2) def test_preconditioner_preserves_psd(self): - M = jnp.array([[2.6452732, -1.4553788, -0.5272188, 0.524349], - [-1.4553788, 4.4274387, 0.21998158, 1.8666775], - [-0.5272188, 0.21998158, 2.4756536, -0.5257966], - [0.524349, 1.8666775, -0.5257966, 2.889879]]).astype( - self.dtype) + M = jnp.array([ + [2.6452732, -1.4553788, -0.5272188, 0.524349], + [-1.4553788, 4.4274387, 0.21998158, 1.8666775], + [-0.5272188, 0.21998158, 2.4756536, -0.5257966], + [0.524349, 1.8666775, -0.5257966, 2.889879], + ]).astype(self.dtype) orig_eigenvalues = jnp.linalg.eigvalsh(M) self.assertFalse((orig_eigenvalues < 0).any()) preconditioner = partial_lanczos.make_lanczos_preconditioner( - M, jax.random.PRNGKey(7)) + M, jax.random.PRNGKey(7) + ) preconditioned_M = preconditioner.solve(M) after_eigenvalues = jnp.linalg.eigvalsh(preconditioned_M) self.assertFalse((after_eigenvalues < 0).any()) @@ -142,9 +160,8 @@ def test_preconditioner_preserves_psd(self): def test_my_tridiagonal_solve(self): empty = jnp.array([]).astype(self.dtype) self.assertEqual( - 0, - partial_lanczos.my_tridiagonal_solve( - empty, empty, empty, empty).size) + 0, partial_lanczos.my_tridiagonal_solve(empty, empty, empty, empty).size + ) np.testing.assert_allclose( jnp.array([2.5]), @@ -152,7 +169,9 @@ def test_my_tridiagonal_solve(self): jnp.array([0.0], dtype=self.dtype), jnp.array([2.0], dtype=self.dtype), jnp.array([0.0], dtype=self.dtype), - jnp.array([5.0], dtype=self.dtype))) + jnp.array([5.0], dtype=self.dtype), + ), + ) np.testing.assert_allclose( jnp.array([-4.5, 3.5]), @@ -160,16 +179,19 @@ def test_my_tridiagonal_solve(self): jnp.array([0.0, 1.0], dtype=self.dtype), jnp.array([2.0, 3.0], dtype=self.dtype), jnp.array([4.0, 0.0], dtype=self.dtype), - jnp.array([5.0, 6.0], dtype=self.dtype))) + jnp.array([5.0, 6.0], dtype=self.dtype), + ), + ) np.testing.assert_allclose( - jnp.array([-33.0/2.0, 115.0/12.0, -11.0/6.0])[:, jnp.newaxis], + jnp.array([-33.0 / 2.0, 115.0 / 12.0, -11.0 / 6.0])[:, jnp.newaxis], partial_lanczos.my_tridiagonal_solve( jnp.array([0.0, 1.0, 2.0], dtype=self.dtype), jnp.array([3.0, 4.0, 5.0], dtype=self.dtype), jnp.array([6.0, 7.0, 0.0], dtype=self.dtype), - jnp.array([8.0, 9.0, 10.0], dtype=self.dtype)[:, jnp.newaxis]), - atol=1e-6 + jnp.array([8.0, 9.0, 10.0], dtype=self.dtype)[:, jnp.newaxis], + ), + atol=1e-6, ) def test_psd_solve_multishift(self): @@ -178,12 +200,12 @@ def test_psd_solve_multishift(self): lambda x: x, v[:, jnp.newaxis], jnp.array([0.0, 2.0, -1.0], dtype=self.dtype), - jax.random.PRNGKey(8)) + jax.random.PRNGKey(8), + ) np.testing.assert_allclose( solutions[:, 0, :], - [[1.0, 1.0, 1.0, 1.0], - [-1.0, -1.0, -1.0, -1.0], - [0.5, 0.5, 0.5, 0.5]]) + [[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0], [0.5, 0.5, 0.5, 0.5]], + ) class PartialLanczosTestFloat32(_PartialLanczosTest): diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners.py b/tensorflow_probability/python/experimental/fastgp/preconditioners.py index d891cd9531..71ad8d0806 100644 --- a/tensorflow_probability/python/experimental/fastgp/preconditioners.py +++ b/tensorflow_probability/python/experimental/fastgp/preconditioners.py @@ -43,49 +43,49 @@ from jaxtyping import Float from tensorflow_probability.python.experimental.fastgp import linalg from tensorflow_probability.python.experimental.fastgp import linear_operator_sum -from 
tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.python.internal.backend import jax as tf2jax +from tensorflow_probability.substrates.jax.math import linalg as tfp_math -jtf = tfp.tf2jax # pylint: disable=invalid-name @jax.named_call -def promote_to_operator(M) -> jtf.linalg.LinearOperator: - if isinstance(M, jtf.linalg.LinearOperator): +def promote_to_operator(M) -> tf2jax.linalg.LinearOperator: + if isinstance(M, tf2jax.linalg.LinearOperator): return M - return jtf.linalg.LinearOperatorFullMatrix(M, is_non_singular=True) + return tf2jax.linalg.LinearOperatorFullMatrix(M, is_non_singular=True) def _diag_part(M) -> jax.Array: - if isinstance(M, jtf.linalg.LinearOperator): + if isinstance(M, tf2jax.linalg.LinearOperator): return M.diag_part() - return jtf.linalg.diag_part(M) + return tf2jax.linalg.diag_part(M) class Preconditioner: """Base class for preconditioners.""" - def __init__(self, M: jtf.linalg.LinearOperator): + def __init__(self, M: tf2jax.linalg.LinearOperator): self.M = M - def full_preconditioner(self) -> jtf.linalg.LinearOperator: + def full_preconditioner(self) -> tf2jax.linalg.LinearOperator: """Returns the preconditioner.""" raise NotImplementedError('Base classes must override full_preconditioner.') - def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator: """Returns the combined action of M and the preconditioner.""" raise NotImplementedError( 'Base classes must override preconditioned_operator.') - def log_det(self) -> jtf.linalg.LinearOperator: + def log_det(self) -> tf2jax.linalg.LinearOperator: """The log absolute value of the determinant of the preconditioner.""" return self.full_preconditioner().log_abs_determinant() def trace_of_inverse_product(self, A: jax.Array) -> Float: """Returns tr( P^(-1) A ) for a n x n, non-batched A.""" result = self.full_preconditioner().solve(A) - if isinstance(result, jtf.linalg.LinearOperator): + if isinstance(result, tf2jax.linalg.LinearOperator): return result.trace() return jnp.trace(result) @@ -94,15 +94,15 @@ def trace_of_inverse_product(self, A: jax.Array) -> Float: class IdentityPreconditioner(Preconditioner): """The do-nothing preconditioner.""" - def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs): + def __init__(self, M: tf2jax.linalg.LinearOperator, **unused_kwargs): n = M.shape[-1] - self.id = jtf.linalg.LinearOperatorIdentity(n, dtype=M.dtype) + self.id = tf2jax.linalg.LinearOperatorIdentity(n, dtype=M.dtype) super().__init__(M) - def full_preconditioner(self) -> jtf.linalg.LinearOperator: + def full_preconditioner(self) -> tf2jax.linalg.LinearOperator: return self.id - def preconditioned_operator(self) -> jtf.linalg.LinearOperator: + def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator: return promote_to_operator(self.M) def log_det(self) -> Float: @@ -123,17 +123,17 @@ def tree_unflatten(cls, unused_aux_data, children): class DiagonalPreconditioner(Preconditioner): """The best diagonal preconditioner; aka the Jacobi preconditioner.""" - def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs): + def __init__(self, M: tf2jax.linalg.LinearOperator, **unused_kwargs): self.d = jnp.maximum(_diag_part(M), 1e-6) super().__init__(M) - def full_preconditioner(self) -> jtf.linalg.LinearOperator: - return jtf.linalg.LinearOperatorDiag( + def full_preconditioner(self) -> tf2jax.linalg.LinearOperator: + return tf2jax.linalg.LinearOperatorDiag( self.d, is_non_singular=True, 
is_positive_definite=True ) - def preconditioned_operator(self) -> jtf.linalg.LinearOperator: - return jtf.linalg.LinearOperatorComposition( + def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator: + return tf2jax.linalg.LinearOperatorComposition( [promote_to_operator(self.M), self.full_preconditioner().inverse()] ) @@ -157,7 +157,7 @@ class LowRankPreconditioner(Preconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, low_rank: jax.Array, residual_diag: jax.Array = None, ): @@ -177,17 +177,19 @@ def __init__( self.residual_diag = jnp.maximum(1e-6, self.residual_diag) - diag_op = jtf.linalg.LinearOperatorDiag( - self.residual_diag, is_non_singular=True, is_positive_definite=True) - self.pre = jtf.linalg.LinearOperatorLowRankUpdate( - diag_op, self.low_rank, is_positive_definite=True) + diag_op = tf2jax.linalg.LinearOperatorDiag( + self.residual_diag, is_non_singular=True, is_positive_definite=True + ) + self.pre = tf2jax.linalg.LinearOperatorLowRankUpdate( + diag_op, self.low_rank, is_positive_definite=True + ) super().__init__(M) - def full_preconditioner(self) -> jtf.linalg.LinearOperator: + def full_preconditioner(self) -> tf2jax.linalg.LinearOperator: return self.pre - def preconditioned_operator(self) -> jtf.linalg.LinearOperator: - return jtf.linalg.LinearOperatorComposition( + def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator: + return tf2jax.linalg.LinearOperatorComposition( [promote_to_operator(self.M), self.pre.inverse()] ) @@ -212,7 +214,7 @@ class RankOnePreconditioner(LowRankPreconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, key: jax.Array, num_iters: int = 10, **unused_kwargs, @@ -229,13 +231,13 @@ class PartialCholeskyPreconditioner(LowRankPreconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, rank: int = 20, **unused_kwargs, ): n = M.shape[-1] rank = min(n, rank) - low_rank, _, residual_diag = tfp.math.low_rank_cholesky(M, rank) + low_rank, _, residual_diag = tfp_math.low_rank_cholesky(M, rank) super().__init__(M, low_rank, residual_diag) @@ -245,7 +247,7 @@ class PartialLanczosPreconditioner(LowRankPreconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, key: jax.Array, rank: int = 20, **unused_kwargs, @@ -263,7 +265,7 @@ class TruncatedSvdPreconditioner(LowRankPreconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, key: jax.Array, rank: int = 20, num_iters: int = 10, @@ -282,7 +284,7 @@ class TruncatedRandomizedSvdPreconditioner(LowRankPreconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, key: jax.Array, rank: int = 20, num_iters: int = 10, @@ -298,7 +300,7 @@ class LowRankPlusScalingPreconditioner(Preconditioner): def __init__( self, - M: jtf.linalg.LinearOperator, + M: tf2jax.linalg.LinearOperator, low_rank: jax.Array, scaling: jax.Array, ): @@ -309,30 +311,35 @@ def __init__( f' ({M.shape[-1]}, r)' ) self.scaling = scaling - identity_op = jtf.linalg.LinearOperatorScaledIdentity( + identity_op = tf2jax.linalg.LinearOperatorScaledIdentity( num_rows=M.shape[-1], multiplier=self.scaling, is_non_singular=True, - is_positive_definite=True) - self.pre = jtf.linalg.LinearOperatorLowRankUpdate( + is_positive_definite=True, + ) + self.pre = tf2jax.linalg.LinearOperatorLowRankUpdate( identity_op, self.low_rank, is_positive_definite=True, is_self_adjoint=True, - 
        is_non_singular=True)
+        is_non_singular=True,
+    )
     super().__init__(M)
 
-  def full_preconditioner(self) -> jtf.linalg.LinearOperator:
+  def full_preconditioner(self) -> tf2jax.linalg.LinearOperator:
     return self.pre
 
-  def preconditioned_operator(self) -> jtf.linalg.LinearOperator:
+  def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator:
     linop = promote_to_operator(self.M)
-    operator = linear_operator_sum.LinearOperatorSum(
-        [linop,
-         jtf.linalg.LinearOperatorScaledIdentity(
-             num_rows=self.M.shape[-1],
-             multiplier=self.scaling)])
-    return jtf.linalg.LinearOperatorComposition([self.pre.inverse(), operator])
+    operator = linear_operator_sum.LinearOperatorSum([
+        linop,
+        tf2jax.linalg.LinearOperatorScaledIdentity(
+            num_rows=self.M.shape[-1], multiplier=self.scaling
+        ),
+    ])
+    return tf2jax.linalg.LinearOperatorComposition(
+        [self.pre.inverse(), operator]
+    )
 
   @classmethod
   def from_lowrank(cls, M, low_rank, scaling):
@@ -356,14 +363,14 @@ class PartialCholeskyPlusScalingPreconditioner(
 
   def __init__(
       self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
       scaling: jax.Array,
       rank: int = 20,
       **unused_kwargs,
   ):
     n = M.shape[-1]
     rank = min(n, rank)
-    low_rank, _, _ = tfp.math.low_rank_cholesky(M, rank)
+    low_rank, _, _ = tfp_math.low_rank_cholesky(M, rank)
     super().__init__(M, low_rank, scaling)
 
 
@@ -374,7 +381,7 @@ class PartialPivotedCholeskyPlusScalingPreconditioner(
 
   def __init__(
       self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
       scaling: jax.Array,
       rank: int = 20,
       **unused_kwargs,
@@ -392,7 +399,7 @@ class PartialLanczosPlusScalingPreconditioner(
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      scaling: jax.Array,
      key: jax.Array,
      rank: int = 20,
@@ -411,7 +418,7 @@ class TruncatedSvdPlusScalingPreconditioner(LowRankPlusScalingPreconditioner):
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      scaling: jax.Array,
      key: jax.Array,
      rank: int = 20,
@@ -432,7 +439,7 @@ class TruncatedRandomizedSvdPlusScalingPreconditioner(
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      scaling: jax.Array,
      key: jax.Array,
      rank: int = 20,
@@ -447,28 +454,28 @@ class SplitPreconditioner(Preconditioner):
   """Base class for symmetric split preconditioners."""
 
   # pylint: disable-next=useless-parent-delegation
-  def __init__(self, M: jtf.linalg.LinearOperator):
+  def __init__(self, M: tf2jax.linalg.LinearOperator):
     super().__init__(M)
 
-  def right_half(self) -> jtf.linalg.LinearOperator:
+  def right_half(self) -> tf2jax.linalg.LinearOperator:
     """Returns R, where the preconditioner is P = R^T R."""
     raise NotImplementedError('Base classes must override right_half method.')
 
-  def full_preconditioner(self) -> jtf.linalg.LinearOperator:
+  def full_preconditioner(self) -> tf2jax.linalg.LinearOperator:
     """Returns P = R^T R, the preconditioner's approximation to M."""
     rh = self.right_half()
     lh = rh.adjoint()
-    return jtf.linalg.LinearOperatorComposition(
+    return tf2jax.linalg.LinearOperatorComposition(
         [lh, rh],
         is_self_adjoint=True,
         is_positive_definite=True,
     )
 
-  def preconditioned_operator(self) -> jtf.linalg.LinearOperator:
+  def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator:
     """Returns R^(-T) M R^(-1)."""
     rhi = self.right_half().inverse()
     lhi = rhi.adjoint()
-    return jtf.linalg.LinearOperatorComposition(
+    return tf2jax.linalg.LinearOperatorComposition(
         [lhi, promote_to_operator(self.M), rhi],
         is_self_adjoint=True,
         is_positive_definite=True,
@@ -488,18 +495,18 @@ def trace_of_inverse_product(self, A: jax.Array) -> Float:
 class DiagonalSplitPreconditioner(SplitPreconditioner):
   """The split conditioner which pre and post multiplies by a diagonal."""
 
-  def __init__(self, M: jtf.linalg.LinearOperator, **unused_kwargs):
+  def __init__(self, M: tf2jax.linalg.LinearOperator, **unused_kwargs):
     self.d = jnp.maximum(_diag_part(M), 1e-6)
     self.sqrt_d = jnp.sqrt(self.d)
     super().__init__(M)
 
-  def right_half(self) -> jtf.linalg.LinearOperator:
-    return jtf.linalg.LinearOperatorDiag(
+  def right_half(self) -> tf2jax.linalg.LinearOperator:
+    return tf2jax.linalg.LinearOperatorDiag(
         self.sqrt_d, is_non_singular=True, is_positive_definite=True
     )
 
-  def full_preconditioner(self) -> jtf.linalg.LinearOperator:
-    return jtf.linalg.LinearOperatorDiag(
+  def full_preconditioner(self) -> tf2jax.linalg.LinearOperator:
+    return tf2jax.linalg.LinearOperatorDiag(
         self.d, is_non_singular=True, is_positive_definite=True
     )
 
@@ -523,7 +530,7 @@ class LowRankSplitPreconditioner(SplitPreconditioner):
 
   def __init__(
       self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
       low_rank: jax.Array,
       residual_diag: jax.Array = None,
   ):
@@ -556,14 +563,14 @@ def __init__(
     # Turn off Pyformat because it puts spaces between a slice bound and its
     # colon.
     # fmt: off
-    self.B = jtf.linalg.LinearOperatorFullMatrix(
+    self.B = tf2jax.linalg.LinearOperatorFullMatrix(
         self.low_rank[:self.r, :], is_non_singular=True
     )
-    self.C = jtf.linalg.LinearOperatorFullMatrix(self.low_rank[self.r:, :])
+    self.C = tf2jax.linalg.LinearOperatorFullMatrix(self.low_rank[self.r:, :])
     sqrt_d = jnp.sqrt(self.residual_diag[self.r:])
     # fmt: on
-    self.D = jtf.linalg.LinearOperatorDiag(sqrt_d, is_non_singular=True)
-    P = tfp.tf2jax.linalg.LinearOperatorBlockLowerTriangular(
+    self.D = tf2jax.linalg.LinearOperatorDiag(sqrt_d, is_non_singular=True)
+    P = tf2jax.linalg.LinearOperatorBlockLowerTriangular(
         [[self.B], [self.C, self.D]], is_non_singular=True
     )
     # We started from M ~ low_rank low_rank^t (because low_rank is n by r),
@@ -572,7 +579,7 @@ def __init__(
 
     super().__init__(M)
 
-  def right_half(self) -> jtf.linalg.LinearOperator:
+  def right_half(self) -> tf2jax.linalg.LinearOperator:
     return self.P
 
   def trace_of_inverse_product(self, A: jax.Array) -> Float:
@@ -627,7 +634,7 @@ class RankOneSplitPreconditioner(LowRankSplitPreconditioner):
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      key: jax.Array,
      num_iters: int = 10,
      **unused_kwargs,
@@ -644,13 +651,13 @@ class PartialCholeskySplitPreconditioner(LowRankSplitPreconditioner):
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      rank: int = 20,
      **unused_kwargs,
   ):
     n = M.shape[-1]
     rank = min(n, rank)
-    low_rank, _, residual_diag = tfp.math.low_rank_cholesky(M, rank)
+    low_rank, _, residual_diag = tfp_math.low_rank_cholesky(M, rank)
     super().__init__(M, low_rank, residual_diag)
 
 
@@ -660,7 +667,7 @@ class PartialLanczosSplitPreconditioner(LowRankSplitPreconditioner):
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      key: jax.Array,
      rank: int = 20,
      **unused_kwargs,
@@ -678,7 +685,7 @@ class TruncatedSvdSplitPreconditioner(LowRankSplitPreconditioner):
 
   def __init__(
      self,
-      M: jtf.linalg.LinearOperator,
+      M: tf2jax.linalg.LinearOperator,
      key: jax.Array,
      rank: int = 20,
      num_iters: int = 10,
@@ -716,7 +723,7 @@ def __init__(
 
 @jax.named_call
 def get_preconditioner(
-    preconditioner_name: str, M: jtf.linalg.LinearOperator, **kwargs
+    preconditioner_name: str, M: tf2jax.linalg.LinearOperator, **kwargs
 ) -> SplitPreconditioner:
   """Return the preconditioner of the given type for the given matrix."""
   if preconditioner_name == 'auto':
diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py b/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py
index a0e0c54bd4..21ed8844d2 100644
--- a/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py
+++ b/tensorflow_probability/python/experimental/fastgp/preconditioners_test.py
@@ -20,11 +20,9 @@
 import jax.numpy as jnp
 import numpy as np
 from tensorflow_probability.python.experimental.fastgp import preconditioners
-import tensorflow_probability.substrates.jax as tfp
+from tensorflow_probability.python.internal.backend import jax as tf2jax
 from absl.testing import absltest
 
-jtf = tfp.tf2jax
-
 # pylint: disable=invalid-name
 
 
@@ -679,7 +677,7 @@ def test_preconditioner_with_linop(self, preconditioner):
     A = jax.random.uniform(jax.random.PRNGKey(8), shape=(2, 2),
                            minval=-1.0, maxval=1.0).astype(self.dtype)
     M = A.T @ A + 0.6 * jnp.eye(2).astype(self.dtype)
-    M = jtf.linalg.LinearOperatorFullMatrix(M)
+    M = tf2jax.linalg.LinearOperatorFullMatrix(M)
     # There are no errors.
     _ = preconditioners.get_preconditioner(
         preconditioner, M, key=jax.random.PRNGKey(9), rank=5)
diff --git a/tensorflow_probability/python/experimental/fastgp/schur_complement.py b/tensorflow_probability/python/experimental/fastgp/schur_complement.py
index 1e7e3990fa..ecedf381d4 100644
--- a/tensorflow_probability/python/experimental/fastgp/schur_complement.py
+++ b/tensorflow_probability/python/experimental/fastgp/schur_complement.py
@@ -17,12 +17,14 @@
 import jax
 import jax.numpy as jnp
 from tensorflow_probability.python.experimental.fastgp import mbcg
-from tensorflow_probability.substrates import jax as tfp
+from tensorflow_probability.substrates.jax.bijectors import softplus
+from tensorflow_probability.substrates.jax.internal import distribution_util
+from tensorflow_probability.substrates.jax.internal import dtype_util
+from tensorflow_probability.substrates.jax.internal import nest_util
+from tensorflow_probability.substrates.jax.internal import parameter_properties
+from tensorflow_probability.substrates.jax.math.psd_kernels import positive_semidefinite_kernel
 from tensorflow_probability.substrates.jax.math.psd_kernels.internal import util
 
-jtf = tfp.tf2jax
-parameter_properties = tfp.internal.parameter_properties
-
 
 __all__ = [
     'SchurComplement',
@@ -39,15 +41,18 @@ def _compute_divisor_matrix(
   """Compute the modified kernel with respect to the fixed inputs."""
   divisor_matrix = base_kernel.matrix(fixed_inputs, fixed_inputs)
   if diag_shift is not None:
-    broadcast_shape = tfp.internal.distribution_util.get_broadcast_shape(
-        divisor_matrix, diag_shift[..., jnp.newaxis, jnp.newaxis])
+    broadcast_shape = distribution_util.get_broadcast_shape(
+        divisor_matrix, diag_shift[..., jnp.newaxis, jnp.newaxis]
+    )
     divisor_matrix = jnp.broadcast_to(divisor_matrix, broadcast_shape)
     divisor_matrix = _add_diagonal_shift(
         divisor_matrix, diag_shift[..., jnp.newaxis])
   return divisor_matrix
 
 
-class SchurComplement(tfp.math.psd_kernels.AutoCompositeTensorPsdKernel):
+class SchurComplement(
+    positive_semidefinite_kernel.AutoCompositeTensorPsdKernel
+):
   """The fast SchurComplement kernel.
 
   See tfp.math.psd_kernels.SchurComplement for more details.
@@ -93,17 +98,18 @@ def __init__(self,
 
     if jax.tree_util.treedef_is_leaf(
         jax.tree_util.tree_structure(base_kernel.feature_ndims)):
-      dtype = tfp.internal.dtype_util.common_dtype(
+      dtype = dtype_util.common_dtype(
           [base_kernel, fixed_inputs],
-          dtype_hint=tfp.internal.nest_util.broadcast_structure(
-              base_kernel.feature_ndims, jnp.float32))
+          dtype_hint=nest_util.broadcast_structure(
+              base_kernel.feature_ndims, jnp.float32
+          ),
+      )
     else:
       # If the fixed inputs are not nested, we assume they are of the same
       # float dtype as the remaining parameters.
-      dtype = tfp.internal.dtype_util.common_dtype(
-          [base_kernel,
-           fixed_inputs,
-           diag_shift], jnp.float32)
+      dtype = dtype_util.common_dtype(
+          [base_kernel, fixed_inputs, diag_shift], jnp.float32
+      )
     self._base_kernel = base_kernel
     self._diag_shift = diag_shift
 
@@ -215,11 +221,17 @@ def _parameter_properties(cls, dtype):
         base_kernel=parameter_properties.BatchedComponentProperties(),
         fixed_inputs=parameter_properties.ParameterProperties(
             event_ndims=lambda self: jax.tree_util.tree_map(  # pylint: disable=g-long-lambda
-                lambda nd: nd + 1, self.base_kernel.feature_ndims)),
+                lambda nd: nd + 1, self.base_kernel.feature_ndims
+            )
+        ),
         diag_shift=parameter_properties.ParameterProperties(
             default_constraining_bijector_fn=(
-                lambda: tfp.bijectors.Softplus(  # pylint: disable=g-long-lambda
-                    low=tfp.internal.dtype_util.eps(dtype)))))
+                lambda: softplus.Softplus(  # pylint: disable=g-long-lambda
+                    low=dtype_util.eps(dtype)
+                )
+            )
+        ),
+    )
 
   def _divisor_matrix(self, fixed_inputs=None):
     fixed_inputs = self._fixed_inputs if fixed_inputs is None else fixed_inputs