"""Tests for gp_bandit."""

from typing import Callable
from unittest import mock

import jax
@@ -47,6 +48,69 @@ def _build_mock_continuous_array_specs(n):
4748 return [continuous_spec ] * n
4849
4950
def _setup_lambda_search(
    f: Callable[[float], float], num_trials: int = 100
) -> tuple[gp_bandit.VizierGPBandit, list[vz.Trial], vz.ProblemStatement]:
  """Sets up a GP designer and outputs completed studies for `f`.

  Args:
    f: 1D objective to be optimized, i.e. f(x), where x is a scalar in [-5., 5.)
    num_trials: Number of mock "evaluated" trials to return.

  Returns:
    A GP designer set up for the problem of optimizing the objective, without
    any data updated.
    Evaluated trials against `f`.
    The problem statement used to configure the designer.

  Raises:
    ValueError: If `num_trials` is not positive.
  """
  # `raise` instead of `assert`: asserts are stripped under `python -O`, so
  # validation must not rely on them.
  if num_trials <= 0:
    raise ValueError(
        f'Must provide a positive number of trials. Got {num_trials}.'
    )

  search_space = vz.SearchSpace()
  search_space.root.add_float_param('x0', -5.0, 5.0)
  problem = vz.ProblemStatement(
      search_space=search_space,
      metric_information=vz.MetricsConfig(
          metrics=[
              vz.MetricInformation('obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE),
          ]
      ),
  )

  # Seeded for determinism across test runs.
  suggestions = quasi_random.QuasiRandomDesigner(
      problem.search_space, seed=1
  ).suggest(num_trials)

  obs_trials = []
  for idx, suggestion in enumerate(suggestions):
    trial = suggestion.to_trial(idx)
    x = suggestion.parameters['x0'].value
    trial.complete(vz.Measurement(metrics={'obj': f(x)}))
    obs_trials.append(trial)

  gp_designer = gp_bandit.VizierGPBandit(problem, ard_optimizer=ard_optimizer)
  return gp_designer, obs_trials, problem
93+
94+
def _compute_mse(
    designer: gp_bandit.VizierGPBandit,
    test_trials: list[vz.Trial],
    y_test: list[float],
) -> float:
  """Evaluates the designer's prediction error on the test set.

  Args:
    designer: The GP bandit designer to predict from.
    test_trials: The trials of the test set.
    y_test: The objective values of the test set.

  Returns:
    The MSE of `designer` on `test_trials` and `y_test`.
  """
  preds = designer.predict(test_trials)
  # `np.mean` (not `np.sum`) so the value actually matches the function's
  # name/contract; for fixed-size test sets this rescaling preserves every
  # comparison between designers.
  return float(np.mean(np.square(preds.mean - np.asarray(y_test))))
113+
50114class GoogleGpBanditTest (parameterized .TestCase ):
51115
52116 @parameterized .parameters (
@@ -216,32 +280,8 @@ def test_on_flat_mixed_space(
216280 self .assertFalse (np .isnan (prediction .stddev ).any ())
217281
218282 def test_prediction_accuracy (self ):
219- search_space = vz .SearchSpace ()
220- search_space .root .add_float_param ('x0' , - 5.0 , 5.0 )
221- problem = vz .ProblemStatement (
222- search_space = search_space ,
223- metric_information = vz .MetricsConfig (
224- metrics = [
225- vz .MetricInformation (
226- 'obj' , goal = vz .ObjectiveMetricGoal .MAXIMIZE
227- ),
228- ]
229- ),
230- )
231283 f = lambda x : - ((x - 0.5 ) ** 2 )
232-
233- suggestions = quasi_random .QuasiRandomDesigner (
234- problem .search_space , seed = 1
235- ).suggest (100 )
236-
237- obs_trials = []
238- for idx , suggestion in enumerate (suggestions ):
239- trial = suggestion .to_trial (idx )
240- x = suggestions [idx ].parameters ['x0' ].value
241- trial .complete (vz .Measurement (metrics = {'obj' : f (x )}))
242- obs_trials .append (trial )
243-
244- gp_designer = gp_bandit .VizierGPBandit (problem , ard_optimizer = ard_optimizer )
284+ gp_designer , obs_trials , _ = _setup_lambda_search (f )
245285 gp_designer .update (vza .CompletedTrials (obs_trials ), vza .ActiveTrials ())
246286 pred_trial = vz .Trial ({'x0' : 0.0 })
247287 pred = gp_designer .predict ([pred_trial ])
@@ -261,6 +301,7 @@ def test_jit_once(self, *args):
261301 name = 'metric' , goal = vz .ObjectiveMetricGoal .MAXIMIZE
262302 )
263303 )
304+
264305 def create_designer (problem ):
265306 return gp_bandit .VizierGPBandit (
266307 problem = problem ,
@@ -299,6 +340,83 @@ def create_runner(problem):
299340 create_runner (problem ).run_designer (designer2 )
300341
301342
class GPBanditPriorsTest(parameterized.TestCase):

  def test_prior_warping(self):
    """Tests linear transform of objective has no impact on transfer learning."""
    base_fn = lambda x: -((x - 0.5) ** 2)
    transform_fn = lambda x: -3 * ((x - 0.5) ** 2) + 10

    # X is in range of what is defined in `_setup_lambda_search`, [-5.0, 5.0)
    x_test = np.random.default_rng(1).uniform(-5.0, 5.0, 100)
    y_test = [transform_fn(x) for x in x_test]
    test_trials = [vz.Trial({'x0': x}) for x in x_test]

    # Designer whose prior is trained on trials from the untransformed
    # objective.
    prior_designer, prior_trials, _ = _setup_lambda_search(
        f=base_fn, num_trials=100
    )
    prior_designer.set_priors([vza.CompletedTrials(prior_trials)])

    # Baseline designer without a prior, on the transformed objective
    # `transform_fn`. Both designers are updated with the same study trials,
    # so the prior designer must be resilient to linear transforms between
    # the prior and the top-level study.
    baseline_designer, study_trials, _ = _setup_lambda_search(
        f=transform_fn, num_trials=20
    )

    # Update both designers with the actual study.
    baseline_designer.update(
        vza.CompletedTrials(study_trials), vza.ActiveTrials()
    )
    prior_designer.update(vza.CompletedTrials(study_trials), vza.ActiveTrials())

    # Transfer learning from the prior should yield better test-set accuracy.
    mse_with_prior = _compute_mse(prior_designer, test_trials, y_test)
    mse_without_prior = _compute_mse(baseline_designer, test_trials, y_test)
    self.assertLess(mse_with_prior, mse_without_prior)

  @parameterized.parameters(
      dict(iters=3, batch_size=5),
      dict(iters=5, batch_size=1),
  )
  def test_run_with_priors(self, *, iters, batch_size):
    objective = lambda x: -((x - 0.5) ** 2)

    # Designer with a prior trained on trials from the same objective.
    designer, prior_trials, problem = _setup_lambda_search(
        f=objective, num_trials=100
    )
    designer.set_priors([vza.CompletedTrials(prior_trials)])

    runner = test_runners.RandomMetricsRunner(
        problem,
        iters=iters,
        batch_size=batch_size,
        verbose=1,
        validate_parameters=True,
        seed=1,
    )
    # An end-to-end run should produce exactly iters * batch_size trials.
    self.assertLen(runner.run_designer(designer), iters * batch_size)
418+
419+
302420if __name__ == '__main__' :
303421 # Jax disables float64 computations by default and will silently convert
304422 # float64s to float32s. We must explicitly enable float64.