Skip to content

Commit 5408373

Browse files
vizier-teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 712550496
1 parent e0d923e commit 5408373

File tree

4 files changed

+272
-61
lines changed

4 files changed

+272
-61
lines changed

vizier/_src/algorithms/designers/gp/acquisitions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,31 @@ def __call__(
557557
)()
558558

559559

560+
def create_hv_scalarization(
561+
num_scalarizations: int, labels: types.PaddedArray, rng: jax.Array
562+
):
563+
"""Creates a HyperVolumeScalarization with random weights.
564+
565+
Args:
566+
num_scalarizations: The number of scalarizations to create.
567+
labels: The labels used to create the reference point.
568+
rng: The random key to use for sampling the weights.
569+
570+
Returns:
571+
A HyperVolumeScalarization with random weights.
572+
"""
573+
weights = jax.random.normal(
574+
rng,
575+
shape=(num_scalarizations, labels.shape[1]),
576+
)
577+
weights = jnp.abs(weights)
578+
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
579+
ref_point = (
580+
get_reference_point(labels, scale=0.01) if labels.shape[0] > 0 else None
581+
)
582+
return scalarization.HyperVolumeScalarization(weights, ref_point)
583+
584+
560585
# TODO: What do we end up jitting? If we end up directly jitting this call
561586
# then we should make it `eqx.Module` and set
562587
# `reduction_fn=eqx.field(static=True)` instead.

vizier/_src/algorithms/designers/gp_bandit.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from vizier import algorithms as vza
3737
from vizier import pyvizier as vz
3838
from vizier._src.algorithms.designers import quasi_random
39-
from vizier._src.algorithms.designers import scalarization
4039
from vizier._src.algorithms.designers.gp import acquisitions as acq_lib
4140
from vizier._src.algorithms.designers.gp import gp_models
4241
from vizier._src.algorithms.designers.gp import output_warpers
@@ -202,27 +201,18 @@ def __attrs_post_init__(self):
202201
# Multi-objective overrides.
203202
m_info = self._problem.metric_information
204203
if not m_info.is_single_objective:
205-
num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))
206204

207205
# Create scalarization weights.
208206
self._rng, weights_rng = jax.random.split(self._rng)
209-
weights = jax.random.normal(
210-
weights_rng, shape=(self._num_scalarizations, num_obj)
211-
)
212-
weights = jnp.abs(weights)
213-
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
214207

215208
def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
216209
# Scalarized UCB.
217-
labels_array = data.labels.padded_array
218-
has_labels = labels_array.shape[0] > 0
219-
ref_point = (
220-
acq_lib.get_reference_point(data.labels, self._ref_scaling)
221-
if has_labels
222-
else None
210+
scalarizer = acq_lib.create_hv_scalarization(
211+
self._num_scalarizations, data.labels, weights_rng
223212
)
224-
scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)
225213

214+
labels_array = data.labels.padded_array
215+
has_labels = labels_array.shape[0] > 0
226216
max_scalarized = None
227217
if has_labels:
228218
max_scalarized = jnp.max(scalarizer(labels_array), axis=-1)

0 commit comments

Comments
 (0)