2020
2121import copy
2222import datetime
23+ import enum
2324import random
2425from typing import Any , Callable , Mapping , Optional , Sequence , Union
2526
3536from vizier import algorithms as vza
3637from vizier import pyvizier as vz
3738from vizier ._src .algorithms .designers import quasi_random
39+ from vizier ._src .algorithms .designers import scalarization
3840from vizier ._src .algorithms .designers .gp import acquisitions
3941from vizier ._src .algorithms .designers .gp import output_warpers
4042from vizier ._src .algorithms .optimizers import eagle_strategy as es
5153tfd = tfp .distributions
5254
5355
class MultimetricPromisingRegionPenaltyType(enum.Enum):
  """The type of penalty to apply to the points outside the promising region.

  Configures the penalty term in `PEScoreFunction` for multimetric problems.
  Each metric yields its own per-point penalty; the member selected here
  decides how those per-metric penalties are collapsed into a single penalty
  per point (max, min, or mean over the metric axis, respectively).
  """

  # The penalty is applied to the points outside the union of the promising
  # regions of all metrics.
  UNION = 'union'
  # The penalty is applied to the points outside the intersection of the
  # promising regions of all metrics.
  INTERSECTION = 'intersection'
  # The penalty applied to a point in the search space is the average of
  # the penalties with respect to the promising regions of all metrics.
  AVERAGE = 'average'
71+
72+
5473class UCBPEConfig (eqx .Module ):
5574 """UCB-PE config parameters."""
5675
@@ -92,6 +111,13 @@ class UCBPEConfig(eqx.Module):
92111 optimize_set_acquisition_for_exploration : bool = eqx .field (
93112 default = False , static = True
94113 )
114+ # The type of penalty to apply to the points outside the promising region for
115+ # multimetric problems.
116+ multimetric_promising_region_penalty_type : (
117+ MultimetricPromisingRegionPenaltyType
118+ ) = eqx .field (
119+ default = MultimetricPromisingRegionPenaltyType .AVERAGE , static = True
120+ )
95121
96122 def __repr__ (self ):
97123 return eqx .tree_pformat (self , short_arrays = False )
@@ -155,10 +181,28 @@ def _compute_ucb_threshold(
155181 The predicted mean of the feature array with the maximum UCB among `xs`.
156182 """
157183 pred_mean = gprm .mean ()
158- ucb_values = jnp .where (
159- is_missing , - jnp .inf , pred_mean + ucb_coefficient * gprm .stddev ()
160- )
161- return pred_mean [jnp .argmax (ucb_values )]
184+ if pred_mean .ndim > 1 :
185+ # In the multimetric case, the predicted mean and stddev are of shape
186+ # [num_points, num_metrics].
187+ ucb_values = jnp .where (
188+ jnp .tile (is_missing [:, jnp .newaxis ], (1 , pred_mean .shape [- 1 ])),
189+ - jnp .inf ,
190+ pred_mean + ucb_coefficient * gprm .stddev (),
191+ )
192+ # The indices of the points with the maximum UCB values for each metric.
193+ best_ucb_indices = jnp .argmax (ucb_values , axis = 0 )
194+ return jax .vmap (
195+ lambda pred_mean , best_ucb_idx : pred_mean [best_ucb_idx ],
196+ in_axes = - 1 ,
197+ out_axes = - 1 ,
198+ )(pred_mean , best_ucb_indices )
199+ else :
200+ # In the single metric case, the predicted mean and stddev are of shape
201+ # [num_points].
202+ ucb_values = jnp .where (
203+ is_missing , - jnp .inf , pred_mean + ucb_coefficient * gprm .stddev ()
204+ )
205+ return pred_mean [jnp .argmax (ucb_values )]
162206
163207
164208# TODO: Use acquisitions.TrustRegion instead.
@@ -238,12 +282,45 @@ class UCBScoreFunction(eqx.Module):
238282 on completed and pending trials.
239283 ucb_coefficient: The UCB coefficient.
240284 trust_region: Trust region.
285+ scalarization_weights_rng: Random key for scalarization.
286+ labels: Labels, shaped as [num_index_points, num_metrics].
287+ num_scalarizations: Number of scalarizations.
241288 """
242289
243290 predictive : sp .UniformEnsemblePredictive
244291 predictive_all_features : sp .UniformEnsemblePredictive
245292 ucb_coefficient : jt .Float [jt .Array , '' ]
246293 trust_region : Optional [acquisitions .TrustRegion ]
294+ labels : types .PaddedArray
295+ scalarizer : scalarization .Scalarization
296+
def __init__(
    self,
    predictive: sp.UniformEnsemblePredictive,
    predictive_all_features: sp.UniformEnsemblePredictive,
    ucb_coefficient: jt.Float[jt.Array, ''],
    trust_region: Optional[acquisitions.TrustRegion],
    scalarization_weights_rng: jax.Array,
    labels: types.PaddedArray,
    num_scalarizations: int = 1000,
):
  """Stores the predictives and builds the hypervolume scalarizer.

  Args:
    predictive: Predictive stored on `self.predictive`.
    predictive_all_features: Predictive stored on
      `self.predictive_all_features`.
    ucb_coefficient: The UCB coefficient.
    trust_region: Trust region, or `None`.
    scalarization_weights_rng: Random key used to draw the scalarization
      weights.
    labels: Labels, shaped as [num_index_points, num_metrics].
    num_scalarizations: Number of scalarization weight vectors to draw.
  """
  self.predictive = predictive
  self.predictive_all_features = predictive_all_features
  self.ucb_coefficient = ucb_coefficient
  self.trust_region = trust_region
  self.labels = labels
  # Draw one weight vector per scalarization, with one entry per metric.
  weights = jax.random.normal(
      scalarization_weights_rng,
      shape=(num_scalarizations, self.labels.shape[1]),
  )
  # Make the weights non-negative, then scale each vector to unit L2 norm.
  weights = jnp.abs(weights)
  weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
  # A reference point is only derived when there is at least one label row;
  # otherwise `None` is handed to the scalarizer.
  ref_point = (
      acquisitions.get_reference_point(self.labels, scale=0.01)
      if self.labels.shape[0] > 0
      else None
  )
  self.scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)
247324
248325 def score (
249326 self , xs : types .ModelInput , seed : Optional [jax .Array ] = None
@@ -264,9 +341,26 @@ def score_with_aux(
264341 mean = gprm .mean ()
265342 stddev_from_all = gprm_all_features .stddev ()
266343 acq_values = mean + self .ucb_coefficient * stddev_from_all
344+ # `self.labels` is of shape [num_index_points, num_metrics].
345+ if self .labels .shape [1 ] > 1 :
346+ scalarized = self .scalarizer (acq_values )
347+ padded_labels = self .labels .replace_fill_value (- np .inf ).padded_array
348+ if padded_labels .shape [0 ] > 0 :
349+ # Broadcast max_scalarized to the same shape as scalarized and take max.
350+ max_scalarized = jnp .max (self .scalarizer (padded_labels ), axis = - 1 )
351+ shape_mismatch = len (scalarized .shape ) - len (max_scalarized .shape )
352+ expand_max = jnp .expand_dims (
353+ max_scalarized , axis = range (- shape_mismatch , 0 )
354+ )
355+ scalarized = jnp .maximum (scalarized , expand_max )
356+ scalarized_acq_values = jnp .mean (scalarized , axis = 0 )
357+ else :
358+ scalarized_acq_values = acq_values
267359 if self .trust_region is not None :
268- acq_values = _apply_trust_region (self .trust_region , xs , acq_values )
269- return acq_values , {
360+ scalarized_acq_values = _apply_trust_region (
361+ self .trust_region , xs , scalarized_acq_values
362+ )
363+ return scalarized_acq_values , {
270364 'mean' : mean ,
271365 'stddev' : gprm .stddev (),
272366 'stddev_from_all' : stddev_from_all ,
@@ -303,6 +397,9 @@ class PEScoreFunction(eqx.Module):
303397 explore_ucb_coefficient : jt .Float [jt .Array , '' ]
304398 penalty_coefficient : jt .Float [jt .Array , '' ]
305399 trust_region : Optional [acquisitions .TrustRegion ]
400+ multimetric_promising_region_penalty_type : (
401+ MultimetricPromisingRegionPenaltyType
402+ )
306403
307404 def score (
308405 self , xs : types .ModelInput , seed : Optional [jax .Array ] = None
@@ -333,10 +430,34 @@ def score_with_aux(
333430
334431 gprm_all = self .predictive_all_features .predict (xs )
335432 stddev_from_all = gprm_all .stddev ()
336- acq_values = stddev_from_all + self .penalty_coefficient * jnp .minimum (
433+ penalty = self .penalty_coefficient * jnp .minimum (
337434 explore_ucb - threshold ,
338435 0.0 ,
339436 )
437+ # `stddev_from_all` and `penalty` are of shape
438+ # [num_index_points, num_metrics] for multi-metric problems or
439+ # [num_index_points] for single-metric problems.
440+ if stddev_from_all .ndim > 1 :
441+ if self .multimetric_promising_region_penalty_type == (
442+ MultimetricPromisingRegionPenaltyType .UNION
443+ ):
444+ scalarized_penalty = jnp .max (penalty , axis = - 1 )
445+ elif self .multimetric_promising_region_penalty_type == (
446+ MultimetricPromisingRegionPenaltyType .INTERSECTION
447+ ):
448+ scalarized_penalty = jnp .min (penalty , axis = - 1 )
449+ elif self .multimetric_promising_region_penalty_type == (
450+ MultimetricPromisingRegionPenaltyType .AVERAGE
451+ ):
452+ scalarized_penalty = jnp .mean (penalty , axis = - 1 )
453+ else :
454+ raise ValueError (
455+ 'Unsupported multimetric promising region penalty type:'
456+ f' { self .multimetric_promising_region_penalty_type } '
457+ )
458+ acq_values = jnp .mean (stddev_from_all , axis = - 1 ) + scalarized_penalty
459+ else :
460+ acq_values = stddev_from_all + penalty
340461 if self .trust_region is not None :
341462 acq_values = _apply_trust_region (self .trust_region , xs , acq_values )
342463 return acq_values , {
@@ -537,8 +658,14 @@ def __attrs_post_init__(self):
537658 # Extra validations
538659 if self ._problem .search_space .is_conditional :
539660 raise ValueError (f'{ type (self )} does not support conditional search.' )
540- elif len (self ._problem .metric_information ) != 1 :
541- raise ValueError (f'{ type (self )} works with exactly one metric.' )
661+ elif (
662+ len (self ._problem .metric_information ) != 1
663+ and self ._config .optimize_set_acquisition_for_exploration
664+ ):
665+ raise ValueError (
666+ f'{ type (self )} works with exactly one metric with'
667+ ' `optimize_set_acquisition_for_exploration` enabled.'
668+ )
542669
543670 # Extra initializations.
544671 # Discrete parameters are continuified to account for their actual values.
@@ -554,7 +681,7 @@ def __attrs_post_init__(self):
554681 self ._problem .search_space ,
555682 seed = int (jax .random .randint (qrs_seed , [], 0 , 2 ** 16 )),
556683 )
557- self ._output_warper = None
684+ self ._output_warpers : list [ output_warpers . OutputWarper ] = []
558685
559686 def update (
560687 self , completed : vza .CompletedTrials , all_active : vza .ActiveTrials
@@ -717,10 +844,15 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
717844 data .labels .shape ,
718845 _get_features_shape (data .features ),
719846 )
720- self ._output_warper = output_warpers .create_default_warper ()
721- warped_labels = self ._output_warper .warp (np .array (data .labels .unpad ()))
847+ unpadded_labels = np .asarray (data .labels .unpad ())
848+ warped_labels = []
849+ self ._output_warpers = []
850+ for i in range (data .labels .shape [1 ]):
851+ output_warper = output_warpers .create_default_warper ()
852+ warped_labels .append (output_warper .warp (unpadded_labels [:, i : i + 1 ]))
853+ self ._output_warpers .append (output_warper )
722854 labels = types .PaddedArray .from_array (
723- warped_labels ,
855+ np . concatenate ( warped_labels , axis = - 1 ) ,
724856 data .labels .padded_array .shape ,
725857 fill_value = data .labels .fill_value ,
726858 )
@@ -773,7 +905,10 @@ def _get_predictive_all_features(
773905 # Pending features are only used to predict standard deviation, so their
774906 # labels do not matter, and we simply set them to 0.
775907 dummy_labels = jnp .zeros (
776- shape = (pending_features .continuous .unpad ().shape [0 ], 1 ),
908+ shape = (
909+ pending_features .continuous .unpad ().shape [0 ],
910+ data .labels .shape [- 1 ],
911+ ),
777912 dtype = data .labels .padded_array .dtype ,
778913 )
779914 all_labels = jnp .concatenate ([data .labels .unpad (), dummy_labels ], axis = 0 )
@@ -840,11 +975,14 @@ def _suggest_one(
840975 # When `use_ucb` is true, the acquisition function computes the UCB
841976 # values. Otherwise, it computes the Pure-Exploration acquisition values.
842977 if use_ucb :
978+ scalarization_weights_rng , self ._rng = jax .random .split (self ._rng )
843979 scoring_fn = UCBScoreFunction (
844980 predictive ,
845981 predictive_all_features ,
846982 ucb_coefficient = self ._config .ucb_coefficient ,
847983 trust_region = tr if self ._use_trust_region else None ,
984+ scalarization_weights_rng = scalarization_weights_rng ,
985+ labels = data .labels ,
848986 )
849987 else :
850988 scoring_fn = PEScoreFunction (
@@ -854,6 +992,9 @@ def _suggest_one(
854992 ucb_coefficient = self ._config .ucb_coefficient ,
855993 explore_ucb_coefficient = self ._config .explore_region_ucb_coefficient ,
856994 trust_region = tr if self ._use_trust_region else None ,
995+ multimetric_promising_region_penalty_type = (
996+ self ._config .multimetric_promising_region_penalty_type
997+ ),
857998 )
858999
8591000 if isinstance (acquisition_optimizer , vb .VectorizedOptimizer ):
@@ -910,9 +1051,11 @@ def _suggest_one(
9101051 # debugging needs.
9111052 metadata = best_candidate .metadata .ns (self ._metadata_ns )
9121053 metadata .ns ('prediction_in_warped_y_space' ).update ({
913- 'mean' : f'{ predict_mean [0 ]} ' ,
914- 'stddev' : f'{ predict_stddev [0 ]} ' ,
915- 'stddev_from_all' : f'{ predict_stddev_from_all [0 ]} ' ,
1054+ 'mean' : np .array2string (np .asarray (predict_mean [0 ]), separator = ',' ),
1055+ 'stddev' : np .array2string (np .asarray (predict_stddev [0 ]), separator = ',' ),
1056+ 'stddev_from_all' : np .array2string (
1057+ np .asarray (predict_stddev_from_all [0 ]), separator = ','
1058+ ),
9161059 'acquisition' : f'{ acquisition } ' ,
9171060 'use_ucb' : f'{ use_ucb } ' ,
9181061 'trust_radius' : f'{ tr .trust_radius } ' ,
@@ -1060,20 +1203,36 @@ def sample(
10601203 )
10611204 samples = eqx .filter_jit (acquisitions .sample_from_predictive )(
10621205 predictive , xs , num_samples , key = rng
1063- ) # (num_samples, num_trials)
1064- # Scope the samples to non-padded only (there's a single padded dimension).
1206+ )
1207+ # Scope `samples` to non-padded only (there's a single padded dimension).
1208+ # `samples` has shape: [num_samples, num_trials] for single metric or
1209+ # [num_samples, num_trials, num_metrics] for multi-metric problems.
1210+ if samples .ndim == 2 :
1211+ samples = jnp .expand_dims (samples , axis = - 1 )
10651212 samples = samples [
1066- :, ~ (xs .continuous .is_missing [0 ] | xs .categorical .is_missing [0 ])
1213+ :, ~ (xs .continuous .is_missing [0 ] | xs .categorical .is_missing [0 ]), :
10671214 ]
10681215 # TODO: vectorize output warping.
1069- if self ._output_warper is not None :
1070- return np .vstack ([
1071- self ._output_warper .unwarp (samples [i ][..., np .newaxis ]).reshape (- 1 )
1072- for i in range (samples .shape [0 ])
1073- ])
1216+ if self ._output_warpers :
1217+ unwarped_samples = []
1218+ for metric_idx , output_warper in enumerate (self ._output_warpers ):
1219+ unwarped_samples .append (
1220+ np .vstack ([
1221+ output_warper .unwarp (
1222+ samples [i ][:, metric_idx : metric_idx + 1 ]
1223+ ).reshape (- 1 )
1224+ for i in range (samples .shape [0 ])
1225+ ])
1226+ )
1227+ unwarped_samples = np .stack (unwarped_samples , axis = - 1 )
1228+ if unwarped_samples .shape [- 1 ] > 1 :
1229+ return unwarped_samples
1230+ else :
1231+ return np .squeeze (unwarped_samples , axis = - 1 )
10741232 else :
10751233 raise TypeError (
1076- 'Output warper is expected to be set, but found to be None.'
1234+ 'Output warpers are expected to be set, but found to be'
1235+ f' { self ._output_warpers } .'
10771236 )
10781237
10791238 @profiler .record_runtime
0 commit comments