Skip to content

Commit e780b9c

Browse files
vizier-teamcopybara-github
authored andcommitted
Respect continuous feature bounds in LBFGSBOptimizer
`LBFGSBOptimizer` now uses the actual continuous feature bounds from the converter instead of assuming `[0, 1]`. This includes handling singleton bounds. PiperOrigin-RevId: 817327653
1 parent 0eb59ed commit e780b9c

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

vizier/_src/algorithms/optimizers/lbfgsb_optimizer.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
"""L-BFGS-B Strategy optimizer."""
1818

19-
from typing import Callable, Optional, Union
19+
import itertools
20+
from typing import Callable, Optional, Sequence, Union
2021

2122
import attr
2223
from flax import struct
@@ -51,6 +52,9 @@ class LBFGSBOptimizer:
5152
n_feature_dimensions_with_padding: types.ContinuousAndCategorical[int] = (
5253
struct.field(pytree_node=False)
5354
)
55+
continuous_features_bounds: Sequence[tuple[float, float]] = struct.field(
56+
pytree_node=False
57+
)
5458
# Number of parallel runs of L-BFGS-B.
5559
random_restarts: int = struct.field(pytree_node=False, default=25)
5660
# Number of iterations for each L-BFGS-B run.
@@ -144,10 +148,23 @@ def _opt_score_fn(x):
144148
def setup(rng):
145149
return jax.random.uniform(rng, shape=feature_shape)
146150

147-
# Constraints are [0, 1].
148-
constraints = sp.Constraint(
149-
bounds=(np.zeros(feature_shape), np.ones(feature_shape))
151+
continuous_features_min, continuous_features_max = itertools.zip_longest(
152+
*self.continuous_features_bounds
150153
)
154+
bounds = []
155+
for bound, fill_value in (
156+
(continuous_features_min, 0.0),
157+
(continuous_features_max, 1.0),
158+
):
159+
# Pad the bound with fill_value to the length of the feature dimension
160+
# with padding and account for the parallel dimension, so the shape of the
161+
# bound is the same as `feature_shape` above.
162+
padded_bound = bound + (fill_value,) * (
163+
self.n_feature_dimensions_with_padding.continuous - len(bound)
164+
)
165+
bounds.append(np.stack([padded_bound] * parallel_dim, axis=0))
166+
167+
constraints = sp.Constraint(bounds=(bounds[0], bounds[1]))
151168

152169
new_features, _ = optimize(
153170
jax.vmap(setup)(jax.random.split(init_seed, self.random_restarts)),
@@ -199,9 +216,15 @@ def __call__(
199216
empty_features.continuous.shape[-1],
200217
empty_features.categorical.shape[-1],
201218
)
219+
continous_features_bounds = [
220+
(float(spec.bounds[0]), float(spec.bounds[1]))
221+
for spec in converter.output_specs.continuous
222+
]
223+
202224
return LBFGSBOptimizer(
203225
n_feature_dimensions=n_feature_dimensions,
204226
n_feature_dimensions_with_padding=n_feature_dimensions_with_padding,
227+
continuous_features_bounds=continous_features_bounds,
205228
random_restarts=self.random_restarts,
206229
maxiter=self.maxiter,
207230
)

vizier/_src/algorithms/optimizers/lbfgsb_optimizer_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,23 @@ def test_optimize_candidates_len(self):
4242
res = optimizer(score_fn=score_fn)
4343
self.assertLen(res.rewards, 1)
4444

45+
def test_singleton_constraints_are_respected(self):
46+
problem = vz.ProblemStatement()
47+
problem.search_space.root.add_float_param('f1', 0.0, 10.0)
48+
problem.search_space.root.add_float_param('f2', 0.0, 10.0)
49+
problem.search_space.root.add_float_param('f3', 5.0, 5.0)
50+
converter = converters.TrialToModelInputConverter.from_problem(problem)
51+
score_fn = lambda x, _: jnp.sum(x.continuous.padded_array, axis=-1)
52+
optimizer = lo.LBFGSBOptimizerFactory(random_restarts=10, maxiter=20)(
53+
converter
54+
)
55+
res = optimizer(score_fn=score_fn)
56+
best_candidates = vb.best_candidates_to_trials(res, converter)
57+
best_score = score_fn(converter.to_features(best_candidates), None)
58+
# Evaluating the score function being optimized on the best candidate should
59+
# result in the same value as the output of the optimizer.
60+
self.assertSequenceAlmostEqual(best_score, res.rewards)
61+
4562
def test_best_candidates_count_is_1(self):
4663
problem = vz.ProblemStatement()
4764
problem.search_space.root.add_float_param('f1', 0.0, 1.0)

0 commit comments

Comments
 (0)