diff --git a/tensorflow_probability/python/experimental/distributions/BUILD b/tensorflow_probability/python/experimental/distributions/BUILD index c6d3b45c87..4b98195ec0 100644 --- a/tensorflow_probability/python/experimental/distributions/BUILD +++ b/tensorflow_probability/python/experimental/distributions/BUILD @@ -205,12 +205,12 @@ multi_substrate_py_test( srcs = ["multitask_gaussian_process_test.py"], shard_count = 3, deps = [ + ":multitask_gaussian_process", + ":multitask_gaussian_process_regression_model", # absl/testing:parameterized dep, # numpy dep, # tensorflow dep, "//tensorflow_probability/python/distributions:gaussian_process", - "//tensorflow_probability/python/experimental/distributions:multitask_gaussian_process", - "//tensorflow_probability/python/experimental/distributions:multitask_gaussian_process_regression_model", "//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math/psd_kernels:exponentiated_quadratic", diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py index 9609f63444..6d54186388 100644 --- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py +++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py @@ -249,11 +249,11 @@ def __init__(self, parameters = dict(locals()) with tf.name_scope(name) as name: input_dtype = dtype_util.common_dtype( - dict( - kernel=kernel, - index_points=index_points), + dict(index_points=index_points), dtype_hint=nest_util.broadcast_structure( - kernel.feature_ndims, tf.float32)) + kernel.feature_ndims, tf.float32 + ), + ) # If the input dtype is non-nested float, we infer a single dtype for the # input and the float parameters, which is also the dtype of the MTGP's