Skip to content

Commit

Permalink
Add __init__.py file for fastgp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606326145
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Feb 12, 2024
1 parent b597b1c commit d3116d7
Show file tree
Hide file tree
Showing 13 changed files with 487 additions and 321 deletions.
74 changes: 64 additions & 10 deletions tensorflow_probability/python/experimental/fastgp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,36 @@ package(
],
)

py_library(
name = "fastgp.jax",
srcs = ["__init__.py"],
deps = [
":fast_gp",
":fast_gprm",
":fast_log_det",
":fast_mtgp",
":linalg",
":linear_operator_sum",
":mbcg",
":partial_lanczos",
":preconditioners",
":schur_complement",
"//tensorflow_probability/python/internal:all_util",
],
)

# Dummy libraries to satisfy the multi_substrate_py_library deps of
# tfp/python/experimental:experimental.
py_library(
name = "fastgp",
deps = [],
)

py_library(
name = "fastgp.numpy",
deps = [],
)

py_library(
name = "mbcg",
srcs = ["mbcg.py"],
Expand Down Expand Up @@ -55,7 +85,16 @@ py_library(
":mbcg",
":preconditioners",
# jax dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions:gaussian_process_regression_model.jax",
"//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
"//tensorflow_probability/python/internal:parameter_properties.jax",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:tensor_util.jax",
"//tensorflow_probability/python/internal/backend/jax",
"//tensorflow_probability/python/mcmc:sample_halton_sequence.jax",
],
)

Expand Down Expand Up @@ -83,7 +122,14 @@ py_library(
":mbcg",
":preconditioners",
# jax dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/distributions:distribution.jax",
"//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
"//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
"//tensorflow_probability/python/internal:prefer_static.jax",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:tensor_util.jax",
"//tensorflow_probability/python/internal/backend/jax",
],
)

Expand All @@ -108,7 +154,11 @@ py_library(
":preconditioners",
":schur_complement",
# jax dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
"//tensorflow_probability/python/internal:nest_util.jax",
"//tensorflow_probability/python/internal:parameter_properties.jax",
],
)

Expand All @@ -133,7 +183,7 @@ py_library(
# jax dep,
# jax:experimental_sparse dep,
# jaxtyping dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/internal/backend/jax",
],
)

Expand All @@ -157,18 +207,20 @@ py_library(
":mbcg",
# jax dep,
# scipy dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/internal/backend/jax",
],
)

py_test(
name = "partial_lanczos_test",
srcs = ["partial_lanczos_test.py"],
deps = [
":mbcg",
":partial_lanczos",
# absl/testing:absltest dep,
# jax dep,
# numpy dep,
"//tensorflow_probability/substrates:jax",
],
)

Expand All @@ -183,7 +235,6 @@ py_library(
# jaxtyping dep,
# numpy dep,
# scipy dep,
"//tensorflow_probability/substrates:jax",
],
)

Expand All @@ -206,7 +257,6 @@ py_library(
name = "linear_operator_sum",
srcs = ["linear_operator_sum.py"],
deps = [
"//tensorflow_probability/substrates:jax",
],
)

Expand All @@ -219,7 +269,8 @@ py_library(
# jax dep,
# jax:experimental_sparse dep,
# jaxtyping dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/internal/backend/jax",
"//tensorflow_probability/python/math:linalg.jax",
],
)

Expand All @@ -232,7 +283,7 @@ py_test(
# absl/testing:absltest dep,
# absl/testing:parameterized dep,
# jax dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/internal/backend/jax",
],
)

Expand All @@ -242,7 +293,10 @@ py_library(
deps = [
":preconditioners",
# jax dep,
"//tensorflow_probability/substrates:jax",
"//tensorflow_probability/python/bijectors:softplus.jax",
"//tensorflow_probability/python/internal:distribution_util.jax",
"//tensorflow_probability/python/internal:dtype_util.jax",
"//tensorflow_probability/python/internal:nest_util.jax",
],
)

Expand Down
42 changes: 42 additions & 0 deletions tensorflow_probability/python/experimental/fastgp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2024 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Package for training Gaussian Processes in time less than O(n^3)."""

from tensorflow_probability.python.experimental.fastgp import fast_gp
from tensorflow_probability.python.experimental.fastgp import fast_gprm
from tensorflow_probability.python.experimental.fastgp import fast_log_det
from tensorflow_probability.python.experimental.fastgp import fast_mtgp
from tensorflow_probability.python.experimental.fastgp import linalg
from tensorflow_probability.python.experimental.fastgp import linear_operator_sum
from tensorflow_probability.python.experimental.fastgp import mbcg
from tensorflow_probability.python.experimental.fastgp import partial_lanczos
from tensorflow_probability.python.experimental.fastgp import preconditioners
from tensorflow_probability.python.experimental.fastgp import schur_complement
from tensorflow_probability.python.internal import all_util

_allowed_symbols = [
'fast_log_det',
'fast_gp',
'fast_gprm',
'fast_mtgp',
'linalg',
'linear_operator_sum',
'mbcg',
'partial_lanczos',
'preconditioners',
'schur_complement',
]

all_util.remove_undocumented(__name__, _allowed_symbols)
Loading

0 comments on commit d3116d7

Please sign in to comment.