#
# DKAFT
#
# Copyright (c) Siemens AG, 2021
# Authors:
# Zhiliang Wu <[email protected]>
# License-Identifier: MIT
import warnings
import torch
import gpytorch
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.warnings import GPInputWarning
from gpytorch.models.exact_prediction_strategies import prediction_strategy
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel


class SVGPLayer(ApproximateGP):
    """The SVGP output layer with an RBF kernel."""

    def __init__(self, inducing_points):
        """
        Args:
            inducing_points (torch.Tensor): The initial inducing points.
        """
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(0))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution,
            learn_inducing_locations=True)
        super(SVGPLayer, self).__init__(variational_strategy)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
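# A minimal usage sketch (not part of the original file): wiring SVGPLayer to
# a Gaussian likelihood and a variational ELBO. The tensor shapes, the number
# of inducing points, and num_data below are illustrative assumptions, not
# values taken from the DKAFT experiments.
def _example_svgp_usage():
    inducing_points = torch.randn(64, 10)  # assumed: 64 points, 10-d features
    model = SVGPLayer(inducing_points)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    # num_data must equal the size of the full training set for a correct ELBO
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=1000)
    x = torch.randn(32, 10)   # a mini-batch of extracted features
    y = torch.randn(32)       # matching regression targets
    output = model(x)         # q(f) evaluated at the batch inputs
    return -mll(output, y)    # negative ELBO, to be minimized

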
class ExactGPModel(gpytorch.models.ExactGP):
    """The ExactGP output layer with an RBF kernel."""

    def __init__(self, train_x, train_y, feature_extractor, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())
        self.feature_extractor = feature_extractor
        self.training_size = train_y.size(0)

    def forward(self, *inputs):
        # Project the raw inputs through the feature extractor before
        # computing the GP mean and covariance on the learned features.
        # Unpacking *inputs (rather than passing the tuple itself) lets a
        # standard nn.Module extractor receive a plain tensor.
        projected_x = self.feature_extractor(*inputs)
        mean_x = self.mean_module(projected_x)
        covar_x = self.covar_module(projected_x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    def __call__(self, *args, **kwargs):
        """Adapted from ``gpytorch.models.ExactGP.__call__`` so that
        ``forward``, and hence the feature extractor, runs in every mode.
        ``super(gpytorch.models.ExactGP, self)`` deliberately skips ExactGP
        in the MRO and dispatches to ``nn.Module.__call__``."""
        train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

        # Training mode: optimizing
        if self.training:
            if self.train_inputs is None:
                raise RuntimeError(
                    "train_inputs, train_targets cannot be None in training mode. "
                    "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                )
            if settings.debug.on():
                if not all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
                    raise RuntimeError("You must train on the training inputs!")
            res = super(gpytorch.models.ExactGP, self).__call__(*inputs, **kwargs)
            return res
        # Prior mode
        elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            full_inputs = args
            full_output = super(gpytorch.models.ExactGP, self).__call__(*full_inputs, **kwargs)
            if settings.debug.on():
                if not isinstance(full_output, MultivariateNormal):
                    raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
            return full_output
        # Posterior mode
        else:
            if settings.debug.on():
                if all(torch.equal(train_input, input) for train_input, input in zip(train_inputs, inputs)):
                    warnings.warn(
                        "The input matches the stored training data. Did you forget to call model.train()?",
                        GPInputWarning,
                    )

            # Get the terms that only depend on training data
            if self.prediction_strategy is None:
                train_output = super(gpytorch.models.ExactGP, self).__call__(*train_inputs, **kwargs)

                # Create the prediction strategy for computing the posterior
                self.prediction_strategy = prediction_strategy(
                    train_inputs=train_inputs,
                    train_prior_dist=train_output,
                    train_labels=self.train_targets,
                    likelihood=self.likelihood,
                )
            # Concatenate the input to the training inputs
            full_inputs = []
            batch_shape = train_inputs[0].shape[:-2]
            for train_input, input in zip(train_inputs, inputs):
                # Make sure the batch shapes agree for training/test data
                # (batch-shape broadcasting is left disabled in this
                # adaptation; the upstream logic is kept below for reference)
                # if batch_shape != train_input.shape[:-2]:
                #     batch_shape = _mul_broadcast_shape(batch_shape, train_input.shape[:-2])
                #     train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                # if batch_shape != input.shape[:-2]:
                #     batch_shape = _mul_broadcast_shape(batch_shape, input.shape[:-2])
                #     train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                #     input = input.expand(*batch_shape, *input.shape[-2:])
                full_inputs.append(torch.cat([train_input, input], dim=0))
            # Get the joint distribution for training/test data
            full_output = super(gpytorch.models.ExactGP, self).__call__(*full_inputs, **kwargs)
            if settings.debug.on():
                if not isinstance(full_output, MultivariateNormal):
                    raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
            full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix

            # Determine the shape of the joint distribution
            batch_shape = full_output.batch_shape
            joint_shape = full_output.event_shape
            tasks_shape = joint_shape[1:]  # For multitask learning
            test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])

            # Make the prediction
            with settings._use_eval_tolerance():
                predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)

            # Reshape predictive mean to match the appropriate event shape
            predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
            return full_output.__class__(predictive_mean, predictive_covar)
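# A minimal training/prediction sketch (not part of the original file): using
# ExactGPModel with a placeholder identity feature extractor. In DKAFT the
# extractor would be a neural network; all shapes below are illustrative
# assumptions.
def _example_exact_gp_usage():
    feature_extractor = torch.nn.Identity()  # placeholder for a deep network
    train_x = torch.randn(100, 5)
    train_y = torch.randn(100)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(train_x, train_y, feature_extractor, likelihood)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # Training mode: the custom __call__ insists on the stored training inputs
    model.train()
    likelihood.train()
    loss = -mll(model(train_x), train_y)

    # Eval mode: posterior predictions at new inputs
    model.eval()
    likelihood.eval()
    with torch.no_grad():
        preds = likelihood(model(torch.randn(10, 5)))
    return loss, preds

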
if __name__ == '__main__':
    pass