Skip to content

Commit 33e91e1

Browse files
vizier-team authored and copybara-github committed
Supports multimetrics in VizierGaussianProcess
PiperOrigin-RevId: 709075734
1 parent 0429117 commit 33e91e1

File tree

4 files changed

+125
-73
lines changed

4 files changed

+125
-73
lines changed

vizier/_src/algorithms/designers/gp/gp_models.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from vizier._src.algorithms.designers.gp import transfer_learning as vtl
2929
from vizier._src.jax import stochastic_process_model as sp
3030
from vizier._src.jax import types
31-
from vizier._src.jax.models import multitask_tuned_gp_models
3231
from vizier._src.jax.models import tuned_gp_models
3332
from vizier.jax import optimizers
3433
from vizier.utils import profiler
@@ -155,19 +154,9 @@ def get_vizier_gp_coroutine(
155154
Returns:
156155
The model coroutine.
157156
"""
158-
# Construct the multi-task GP.
159-
if data.labels.shape[-1] > 1:
160-
gp_coroutine = multitask_tuned_gp_models.VizierMultitaskGaussianProcess(
161-
_feature_dim=types.ContinuousAndCategorical[int](
162-
data.features.continuous.padded_array.shape[-1],
163-
data.features.categorical.padded_array.shape[-1],
164-
),
165-
_num_tasks=data.labels.shape[-1],
166-
)
167-
return sp.StochasticProcessModel(gp_coroutine).coroutine
168-
169157
return tuned_gp_models.VizierGaussianProcess.build_model(
170-
data.features, linear_coef=linear_coef
158+
data,
159+
linear_coef=linear_coef,
171160
).coroutine
172161

173162

vizier/_src/algorithms/designers/gp_ucb_pe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,10 @@ def _build_gp_model_and_optimize_parameters(
619619
`data.labels`. If `data.features` is empty, the returned parameters are
620620
initial values picked by the GP model.
621621
"""
622+
# TODO: Create a new abstract base class for GP models with a
623+
# `build_model` API to avoid disabling the pytype attribute-error.
622624
coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error
623-
data.features
625+
data
624626
).coroutine
625627
model = sp.CoroutineWithData(coroutine, data)
626628

vizier/_src/jax/models/tuned_gp_models.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
tfb = tfp.bijectors
3535
tfd = tfp.distributions
36+
tfde = tfp.experimental.distributions
3637
tfpk = tfp.math.psd_kernels
3738
tfpke = tfp.experimental.psd_kernels
3839

@@ -86,26 +87,34 @@ class VizierGaussianProcess(sp.ModelCoroutine[tfd.GaussianProcess]):
8687
"""
8788

8889
_dim: types.ContinuousAndCategorical[int] = struct.field(pytree_node=False)
90+
_num_metrics: int = struct.field(pytree_node=False)
8991
_use_retrying_cholesky: bool = struct.field(
9092
pytree_node=False, default=True, kw_only=True
9193
)
9294
_boundary_epsilon: float = struct.field(default=1e-12, kw_only=True)
9395
_linear_coef: Optional[float] = struct.field(default=None, kw_only=True)
9496

97+
def __attrs_post_init__(self):
  """Validates the configured number of metrics after construction.

  Raises:
    ValueError: If `_num_metrics` is less than 1.
  """
  # NOTE(review): `__attrs_post_init__` is the attrs post-construction hook,
  # while the fields of this class are declared with `struct.field`. If the
  # class is a dataclass rather than an attrs class, this method is never
  # invoked automatically and the hook should be `__post_init__` -- confirm.
  if self._num_metrics < 1:
    # The message must be an f-string; otherwise the literal text
    # `{self._num_metrics}` is emitted instead of the offending value.
    raise ValueError(
        f'Number of metrics must be at least 1, got: {self._num_metrics}'
    )
102+
95103
@classmethod
96104
def build_model(
97105
cls,
98-
features: types.ModelInput,
106+
data: types.ModelData,
99107
*,
100108
use_retrying_cholesky: bool = True,
101109
linear_coef: Optional[float] = None,
102110
) -> sp.StochasticProcessModel:
103-
"""Returns the model and loss function."""
111+
"""Returns a StochasticProcessModel for the GP."""
104112
gp_coroutine = VizierGaussianProcess(
105113
_dim=types.ContinuousAndCategorical[int](
106-
features.continuous.padded_array.shape[-1],
107-
features.categorical.padded_array.shape[-1],
114+
data.features.continuous.padded_array.shape[-1],
115+
data.features.categorical.padded_array.shape[-1],
108116
),
117+
_num_metrics=data.labels.shape[-1],
109118
_use_retrying_cholesky=use_retrying_cholesky,
110119
_linear_coef=linear_coef,
111120
)
@@ -122,7 +131,9 @@ def __call__(
122131
continuous_feature_dim), (num_examples, categorical_feature_dim).
123132
124133
Yields:
125-
GaussianProcess whose event shape is `num_examples`.
134+
GaussianProcess whose event shape is `num_examples` for single-metric GP
135+
and MultiTaskGaussianProcess with event shape
136+
`[num_examples, num_metrics]` for multimetric GP.
126137
"""
127138
eps = self._boundary_epsilon
128139
observation_noise_bounds = (np.float64(1e-10 - eps), 1.0 + eps)
@@ -214,8 +225,11 @@ def __call__(
214225
# output a shape of `[batch_shape, 1]`, ensuring that batch dimensions
215226
# line up properly.
216227
mean_fn_constant = yield sp.ModelParameter(
217-
init_fn=lambda k: jax.random.normal(key=k, shape=[1]),
218-
regularizer=lambda x: 0.5 * jnp.squeeze(x, axis=-1) ** 2,
228+
init_fn=lambda k: jax.random.normal(
229+
key=k,
230+
shape=[1] if self._num_metrics == 1 else [1, self._num_metrics],
231+
),
232+
regularizer=lambda x: 0.5 * jnp.sum(x**2),
219233
name='mean_fn',
220234
)
221235

@@ -256,10 +270,19 @@ def __call__(
256270
)
257271
cholesky_fn = lambda matrix: retrying_cholesky(matrix)[0]
258272

259-
return tfd.GaussianProcess(
260-
kernel,
261-
index_points=inputs,
262-
observation_noise_variance=observation_noise_variance,
263-
cholesky_fn=cholesky_fn,
264-
mean_fn=mean_fn,
265-
)
273+
if self._num_metrics > 1:
274+
return tfde.MultiTaskGaussianProcess(
275+
tfpke.Independent(self._num_metrics, kernel),
276+
index_points=inputs,
277+
observation_noise_variance=observation_noise_variance,
278+
cholesky_fn=cholesky_fn,
279+
mean_fn=mean_fn,
280+
)
281+
else:
282+
return tfd.GaussianProcess(
283+
kernel,
284+
index_points=inputs,
285+
observation_noise_variance=observation_noise_variance,
286+
cholesky_fn=cholesky_fn,
287+
mean_fn=mean_fn,
288+
)

vizier/_src/jax/models/tuned_gp_models_test.py

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from absl import logging
2020
import equinox as eqx
2121
import jax
22-
from jax import config
2322
import numpy as np
2423
from tensorflow_probability.substrates import jax as tfp
2524
from vizier._src.jax import stochastic_process_model as sp
@@ -28,13 +27,14 @@
2827
from vizier.jax import optimizers
2928

3029
from absl.testing import absltest
30+
from absl.testing import parameterized
3131

3232
tfb = tfp.bijectors
3333

3434

35-
class VizierGpTest(absltest.TestCase):
35+
class VizierGpTest(parameterized.TestCase):
3636

37-
def _generate_xys(self):
37+
def _generate_xys(self, num_metrics: int):
3838
x_obs = np.array(
3939
[
4040
[
@@ -120,46 +120,58 @@ def _generate_xys(self):
120120
],
121121
dtype=np.float64,
122122
)
123-
y_obs = np.array(
124-
[
125-
0.55552674,
126-
-0.29054829,
127-
-0.04703586,
128-
0.0217839,
129-
0.15445438,
130-
0.46654119,
131-
0.12255823,
132-
-0.19540335,
133-
-0.11772564,
134-
-0.44447326,
135-
],
136-
dtype=np.float64,
137-
)[:, np.newaxis]
123+
y_obs = np.tile(
124+
np.array(
125+
[
126+
0.55552674,
127+
-0.29054829,
128+
-0.04703586,
129+
0.0217839,
130+
0.15445438,
131+
0.46654119,
132+
0.12255823,
133+
-0.19540335,
134+
-0.11772564,
135+
-0.44447326,
136+
],
137+
dtype=np.float64,
138+
)[
139+
:, np.newaxis
140+
], # Added a new axis to be compatible with `np.tile`.
141+
(1, num_metrics),
142+
)
138143
return x_obs, y_obs
139144

140145
# TODO: Define generic assertions for loss values/masking in
141146
# coroutines.
142-
def test_masking_works(self):
143-
# Mask three dimensions and four observations.
144-
x_obs, y_obs = self._generate_xys()
147+
@parameterized.parameters(
148+
# Pads two observations.
149+
dict(num_metrics=1, num_obs=12),
150+
# No observations are padded because multimetric GP does not support
151+
# observation padding.
152+
dict(num_metrics=2, num_obs=10),
153+
)
154+
def test_masking_works(self, num_metrics: int, num_obs: int):
155+
x_obs, y_obs = self._generate_xys(num_metrics)
145156
data = types.ModelData(
146157
features=types.ModelInput(
158+
# Pads three continuous dimensions.
147159
continuous=types.PaddedArray.from_array(
148-
x_obs, target_shape=(12, 9), fill_value=1.0
160+
x_obs, target_shape=(num_obs, 9), fill_value=1.0
149161
),
150162
categorical=types.PaddedArray.from_array(
151163
np.zeros((9, 0), dtype=types.INT_DTYPE),
152-
target_shape=(12, 2),
164+
target_shape=(num_obs, 2),
153165
fill_value=1,
154166
),
155167
),
156168
labels=types.PaddedArray.from_array(
157-
y_obs, target_shape=(12, 1), fill_value=np.nan
169+
y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
158170
),
159171
)
160172
model1 = sp.CoroutineWithData(
161173
tuned_gp_models.VizierGaussianProcess(
162-
types.ContinuousAndCategorical[int](9, 2)
174+
types.ContinuousAndCategorical[int](9, 2), num_metrics
163175
),
164176
data=data,
165177
)
@@ -173,7 +185,7 @@ def test_masking_works(self):
173185
)
174186
model2 = sp.CoroutineWithData(
175187
tuned_gp_models.VizierGaussianProcess(
176-
types.ContinuousAndCategorical[int](9, 2)
188+
types.ContinuousAndCategorical[int](9, 2), num_metrics
177189
),
178190
data=modified_data,
179191
)
@@ -205,37 +217,44 @@ def test_masking_works(self):
205217
model2.loss_with_aux(optimal_params2)[0],
206218
)
207219

208-
def test_good_log_likelihood(self):
220+
@parameterized.parameters(
221+
# Pads two observations.
222+
dict(num_metrics=1, num_obs=12),
223+
# No observations are padded because multimetric GP does not support
224+
# observation padding.
225+
dict(num_metrics=2, num_obs=10),
226+
)
227+
def test_good_log_likelihood(self, num_metrics: int, num_obs: int):
209228
# We use a fixed random seed for sampling categorical data (and continuous
210229
# data from `_generate_xys`, above) so that the same data is used for every
211230
# test run.
212231
rng, init_rng, cat_rng = jax.random.split(jax.random.PRNGKey(2), 3)
213-
x_cont_obs, y_obs = self._generate_xys()
232+
x_cont_obs, y_obs = self._generate_xys(num_metrics)
214233
data = types.ModelData(
215234
features=types.ModelInput(
216235
continuous=types.PaddedArray.from_array(
217-
x_cont_obs, target_shape=(12, 9), fill_value=np.nan
236+
x_cont_obs, target_shape=(num_obs, 9), fill_value=np.nan
218237
),
219238
categorical=types.PaddedArray.from_array(
220239
jax.random.randint(
221240
cat_rng,
222-
shape=(12, 3),
241+
shape=(num_obs, 3),
223242
minval=0,
224243
maxval=3,
225244
dtype=types.INT_DTYPE,
226245
),
227-
target_shape=(12, 5),
246+
target_shape=(num_obs, 5),
228247
fill_value=-1,
229248
),
230249
),
231250
labels=types.PaddedArray.from_array(
232-
y_obs, target_shape=(12, 1), fill_value=np.nan
251+
y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
233252
),
234253
)
235254
target_loss = -0.2
236255
model = sp.CoroutineWithData(
237256
tuned_gp_models.VizierGaussianProcess(
238-
types.ContinuousAndCategorical[int](9, 5)
257+
types.ContinuousAndCategorical[int](9, 5), num_metrics
239258
),
240259
data=data,
241260
)
@@ -251,37 +270,53 @@ def test_good_log_likelihood(self):
251270
logging.info('Loss: %s', metrics['loss'])
252271
self.assertLess(np.min(metrics['loss']), target_loss)
253272

254-
def test_good_log_likelihood_linear(self):
255-
# We use a fixed random seed for sampling categorical data (and continuous
256-
# data from `_generate_xys`, above) so that the same data is used for every
257-
# test run.
273+
@parameterized.parameters(
274+
# Pads two observations.
275+
dict(num_metrics=1, num_obs=12),
276+
# No observations are padded because multimetric GP does not support
277+
# observation padding.
278+
dict(num_metrics=2, num_obs=10),
279+
)
280+
def test_good_log_likelihood_linear(self, num_metrics: int, num_obs: int):
281+
"""Tests that the GP with linear coef after ARD has good log likelihood.
282+
283+
The tests use a fixed random seed for sampling categorical data (and
284+
continuous data from `_generate_xys`, above) so that the same data is used
285+
for every test run.
286+
287+
Args:
288+
num_metrics: Number of metrics.
289+
num_obs: Number of observations.
290+
"""
258291
rng, init_rng, cat_rng = jax.random.split(jax.random.PRNGKey(2), 3)
259-
x_cont_obs, y_obs = self._generate_xys()
292+
x_cont_obs, y_obs = self._generate_xys(num_metrics)
260293
data = types.ModelData(
261294
features=types.ModelInput(
262295
continuous=types.PaddedArray.from_array(
263-
x_cont_obs, target_shape=(12, 9), fill_value=np.nan
296+
x_cont_obs, target_shape=(num_obs, 9), fill_value=np.nan
264297
),
265298
categorical=types.PaddedArray.from_array(
266299
jax.random.randint(
267300
cat_rng,
268-
shape=(12, 3),
301+
shape=(num_obs, 3),
269302
minval=0,
270303
maxval=3,
271304
dtype=types.INT_DTYPE,
272305
),
273-
target_shape=(12, 5),
306+
target_shape=(num_obs, 5),
274307
fill_value=-1,
275308
),
276309
),
277310
labels=types.PaddedArray.from_array(
278-
y_obs, target_shape=(12, 1), fill_value=np.nan
311+
y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
279312
),
280313
)
281314
target_loss = -0.2
282315
model = sp.CoroutineWithData(
283316
tuned_gp_models.VizierGaussianProcess(
284-
types.ContinuousAndCategorical[int](9, 5), _linear_coef=1.0
317+
types.ContinuousAndCategorical[int](9, 5),
318+
num_metrics,
319+
_linear_coef=1.0,
285320
),
286321
data=data,
287322
)
@@ -327,11 +362,14 @@ def test_good_log_likelihood_linear(self):
327362
),
328363
)
329364
y_pred_mean = predictive.predict(pred_features).mean()
330-
self.assertEqual(y_pred_mean.shape, (best_n, n_pred_features))
365+
self.assertEqual(
366+
y_pred_mean.shape,
367+
(best_n, n_pred_features) + ((num_metrics,) if num_metrics > 1 else ()),
368+
)
331369

332370

333371
if __name__ == '__main__':
334372
# Jax disables float64 computations by default and will silently convert
335373
# float64s to float32s. We must explicitly enable float64.
336-
config.update('jax_enable_x64', True)
374+
jax.config.update('jax_enable_x64', True)
337375
absltest.main()

0 commit comments

Comments
 (0)