|
36 | 36 | from vizier import algorithms as vza |
37 | 37 | from vizier import pyvizier as vz |
38 | 38 | from vizier._src.algorithms.designers import quasi_random |
39 | | -from vizier._src.algorithms.designers import scalarization |
40 | 39 | from vizier._src.algorithms.designers.gp import acquisitions as acq_lib |
41 | 40 | from vizier._src.algorithms.designers.gp import gp_models |
42 | 41 | from vizier._src.algorithms.designers.gp import output_warpers |
@@ -202,27 +201,18 @@ def __attrs_post_init__(self): |
202 | 201 | # Multi-objective overrides. |
203 | 202 | m_info = self._problem.metric_information |
204 | 203 | if not m_info.is_single_objective: |
205 | | - num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE)) |
206 | 204 |
|
207 | 205 | # Create scalarization weights. |
208 | 206 | self._rng, weights_rng = jax.random.split(self._rng) |
209 | | - weights = jax.random.normal( |
210 | | - weights_rng, shape=(self._num_scalarizations, num_obj) |
211 | | - ) |
212 | | - weights = jnp.abs(weights) |
213 | | - weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True) |
214 | 207 |
|
215 | 208 | def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction: |
216 | 209 | # Scalarized UCB. |
217 | | - labels_array = data.labels.padded_array |
218 | | - has_labels = labels_array.shape[0] > 0 |
219 | | - ref_point = ( |
220 | | - acq_lib.get_reference_point(data.labels, self._ref_scaling) |
221 | | - if has_labels |
222 | | - else None |
| 210 | + scalarizer = acq_lib.create_hv_scalarization( |
| 211 | + self._num_scalarizations, data.labels, weights_rng |
223 | 212 | ) |
224 | | - scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point) |
225 | 213 |
|
| 214 | + labels_array = data.labels.padded_array |
| 215 | + has_labels = labels_array.shape[0] > 0 |
226 | 216 | max_scalarized = None |
227 | 217 | if has_labels: |
228 | 218 | max_scalarized = jnp.max(scalarizer(labels_array), axis=-1) |
|
0 commit comments