1919from absl import logging
2020import equinox as eqx
2121import jax
22- from jax import config
2322import numpy as np
2423from tensorflow_probability .substrates import jax as tfp
2524from vizier ._src .jax import stochastic_process_model as sp
2827from vizier .jax import optimizers
2928
3029from absl .testing import absltest
30+ from absl .testing import parameterized
3131
3232tfb = tfp .bijectors
3333
3434
35- class VizierGpTest (absltest .TestCase ):
35+ class VizierGpTest (parameterized .TestCase ):
3636
37- def _generate_xys (self ):
37+ def _generate_xys (self , num_metrics : int ):
3838 x_obs = np .array (
3939 [
4040 [
@@ -120,46 +120,58 @@ def _generate_xys(self):
120120 ],
121121 dtype = np .float64 ,
122122 )
123- y_obs = np .array (
124- [
125- 0.55552674 ,
126- - 0.29054829 ,
127- - 0.04703586 ,
128- 0.0217839 ,
129- 0.15445438 ,
130- 0.46654119 ,
131- 0.12255823 ,
132- - 0.19540335 ,
133- - 0.11772564 ,
134- - 0.44447326 ,
135- ],
136- dtype = np .float64 ,
137- )[:, np .newaxis ]
123+ y_obs = np .tile (
124+ np .array (
125+ [
126+ 0.55552674 ,
127+ - 0.29054829 ,
128+ - 0.04703586 ,
129+ 0.0217839 ,
130+ 0.15445438 ,
131+ 0.46654119 ,
132+ 0.12255823 ,
133+ - 0.19540335 ,
134+ - 0.11772564 ,
135+ - 0.44447326 ,
136+ ],
137+ dtype = np .float64 ,
138+ )[
139+ :, np .newaxis
140+ ], # Added a new axis to be compatible with `np.tile`.
141+ (1 , num_metrics ),
142+ )
138143 return x_obs , y_obs
139144
140145 # TODO: Define generic assertions for loss values/masking in
141146 # coroutines.
142- def test_masking_works (self ):
143- # Mask three dimensions and four observations.
144- x_obs , y_obs = self ._generate_xys ()
147+ @parameterized .parameters (
148+ # Pads two observations.
149+ dict (num_metrics = 1 , num_obs = 12 ),
150+ # No observations are padded because multimetric GP does not support
151+ # observation padding.
152+ dict (num_metrics = 2 , num_obs = 10 ),
153+ )
154+ def test_masking_works (self , num_metrics : int , num_obs : int ):
155+ x_obs , y_obs = self ._generate_xys (num_metrics )
145156 data = types .ModelData (
146157 features = types .ModelInput (
158+ # Pads three continuous dimensions.
147159 continuous = types .PaddedArray .from_array (
148- x_obs , target_shape = (12 , 9 ), fill_value = 1.0
160+ x_obs , target_shape = (num_obs , 9 ), fill_value = 1.0
149161 ),
150162 categorical = types .PaddedArray .from_array (
151163 np .zeros ((9 , 0 ), dtype = types .INT_DTYPE ),
152- target_shape = (12 , 2 ),
164+ target_shape = (num_obs , 2 ),
153165 fill_value = 1 ,
154166 ),
155167 ),
156168 labels = types .PaddedArray .from_array (
157- y_obs , target_shape = (12 , 1 ), fill_value = np .nan
169+ y_obs , target_shape = (num_obs , num_metrics ), fill_value = np .nan
158170 ),
159171 )
160172 model1 = sp .CoroutineWithData (
161173 tuned_gp_models .VizierGaussianProcess (
162- types .ContinuousAndCategorical [int ](9 , 2 )
174+ types .ContinuousAndCategorical [int ](9 , 2 ), num_metrics
163175 ),
164176 data = data ,
165177 )
@@ -173,7 +185,7 @@ def test_masking_works(self):
173185 )
174186 model2 = sp .CoroutineWithData (
175187 tuned_gp_models .VizierGaussianProcess (
176- types .ContinuousAndCategorical [int ](9 , 2 )
188+ types .ContinuousAndCategorical [int ](9 , 2 ), num_metrics
177189 ),
178190 data = modified_data ,
179191 )
@@ -205,37 +217,44 @@ def test_masking_works(self):
205217 model2 .loss_with_aux (optimal_params2 )[0 ],
206218 )
207219
208- def test_good_log_likelihood (self ):
220+ @parameterized .parameters (
221+ # Pads two observations.
222+ dict (num_metrics = 1 , num_obs = 12 ),
223+ # No observations are padded because multimetric GP does not support
224+ # observation padding.
225+ dict (num_metrics = 2 , num_obs = 10 ),
226+ )
227+ def test_good_log_likelihood (self , num_metrics : int , num_obs : int ):
209228 # We use a fixed random seed for sampling categorical data (and continuous
210229 # data from `_generate_xys`, above) so that the same data is used for every
211230 # test run.
212231 rng , init_rng , cat_rng = jax .random .split (jax .random .PRNGKey (2 ), 3 )
213- x_cont_obs , y_obs = self ._generate_xys ()
232+ x_cont_obs , y_obs = self ._generate_xys (num_metrics )
214233 data = types .ModelData (
215234 features = types .ModelInput (
216235 continuous = types .PaddedArray .from_array (
217- x_cont_obs , target_shape = (12 , 9 ), fill_value = np .nan
236+ x_cont_obs , target_shape = (num_obs , 9 ), fill_value = np .nan
218237 ),
219238 categorical = types .PaddedArray .from_array (
220239 jax .random .randint (
221240 cat_rng ,
222- shape = (12 , 3 ),
241+ shape = (num_obs , 3 ),
223242 minval = 0 ,
224243 maxval = 3 ,
225244 dtype = types .INT_DTYPE ,
226245 ),
227- target_shape = (12 , 5 ),
246+ target_shape = (num_obs , 5 ),
228247 fill_value = - 1 ,
229248 ),
230249 ),
231250 labels = types .PaddedArray .from_array (
232- y_obs , target_shape = (12 , 1 ), fill_value = np .nan
251+ y_obs , target_shape = (num_obs , num_metrics ), fill_value = np .nan
233252 ),
234253 )
235254 target_loss = - 0.2
236255 model = sp .CoroutineWithData (
237256 tuned_gp_models .VizierGaussianProcess (
238- types .ContinuousAndCategorical [int ](9 , 5 )
257+ types .ContinuousAndCategorical [int ](9 , 5 ), num_metrics
239258 ),
240259 data = data ,
241260 )
@@ -251,37 +270,53 @@ def test_good_log_likelihood(self):
251270 logging .info ('Loss: %s' , metrics ['loss' ])
252271 self .assertLess (np .min (metrics ['loss' ]), target_loss )
253272
254- def test_good_log_likelihood_linear (self ):
255- # We use a fixed random seed for sampling categorical data (and continuous
256- # data from `_generate_xys`, above) so that the same data is used for every
257- # test run.
273+ @parameterized .parameters (
274+ # Pads two observations.
275+ dict (num_metrics = 1 , num_obs = 12 ),
276+ # No observations are padded because multimetric GP does not support
277+ # observation padding.
278+ dict (num_metrics = 2 , num_obs = 10 ),
279+ )
280+ def test_good_log_likelihood_linear (self , num_metrics : int , num_obs : int ):
281+ """Tests that the GP with linear coef after ARD has good log likelihood.
282+
283+ The tests use a fixed random seed for sampling categorical data (and
284+ continuous data from `_generate_xys`, above) so that the same data is used
285+ for every test run.
286+
287+ Args:
288+ num_metrics: Number of metrics.
289+ num_obs: Number of observations.
290+ """
258291 rng , init_rng , cat_rng = jax .random .split (jax .random .PRNGKey (2 ), 3 )
259- x_cont_obs , y_obs = self ._generate_xys ()
292+ x_cont_obs , y_obs = self ._generate_xys (num_metrics )
260293 data = types .ModelData (
261294 features = types .ModelInput (
262295 continuous = types .PaddedArray .from_array (
263- x_cont_obs , target_shape = (12 , 9 ), fill_value = np .nan
296+ x_cont_obs , target_shape = (num_obs , 9 ), fill_value = np .nan
264297 ),
265298 categorical = types .PaddedArray .from_array (
266299 jax .random .randint (
267300 cat_rng ,
268- shape = (12 , 3 ),
301+ shape = (num_obs , 3 ),
269302 minval = 0 ,
270303 maxval = 3 ,
271304 dtype = types .INT_DTYPE ,
272305 ),
273- target_shape = (12 , 5 ),
306+ target_shape = (num_obs , 5 ),
274307 fill_value = - 1 ,
275308 ),
276309 ),
277310 labels = types .PaddedArray .from_array (
278- y_obs , target_shape = (12 , 1 ), fill_value = np .nan
311+ y_obs , target_shape = (num_obs , num_metrics ), fill_value = np .nan
279312 ),
280313 )
281314 target_loss = - 0.2
282315 model = sp .CoroutineWithData (
283316 tuned_gp_models .VizierGaussianProcess (
284- types .ContinuousAndCategorical [int ](9 , 5 ), _linear_coef = 1.0
317+ types .ContinuousAndCategorical [int ](9 , 5 ),
318+ num_metrics ,
319+ _linear_coef = 1.0 ,
285320 ),
286321 data = data ,
287322 )
@@ -327,11 +362,14 @@ def test_good_log_likelihood_linear(self):
327362 ),
328363 )
329364 y_pred_mean = predictive .predict (pred_features ).mean ()
330- self .assertEqual (y_pred_mean .shape , (best_n , n_pred_features ))
365+ self .assertEqual (
366+ y_pred_mean .shape ,
367+ (best_n , n_pred_features ) + ((num_metrics ,) if num_metrics > 1 else ()),
368+ )
331369
332370
333371if __name__ == '__main__' :
334372 # Jax disables float64 computations by default and will silently convert
335373 # float64s to float32s. We must explicitly enable float64.
336- config .update ('jax_enable_x64' , True )
374+ jax . config .update ('jax_enable_x64' , True )
337375 absltest .main ()