| 
16 | 16 | 
 
  | 
17 | 17 | """L-BFGS-B Strategy optimizer."""  | 
18 | 18 | 
 
  | 
19 |  | -from typing import Callable, Optional, Union  | 
 | 19 | +import itertools  | 
 | 20 | +from typing import Callable, Optional, Sequence, Union  | 
20 | 21 | 
 
  | 
21 | 22 | import attr  | 
22 | 23 | from flax import struct  | 
@@ -51,6 +52,9 @@ class LBFGSBOptimizer:  | 
51 | 52 |   n_feature_dimensions_with_padding: types.ContinuousAndCategorical[int] = (  | 
52 | 53 |       struct.field(pytree_node=False)  | 
53 | 54 |   )  | 
 | 55 | +  continuous_features_bounds: Sequence[tuple[float, float]] = struct.field(  | 
 | 56 | +      pytree_node=False  | 
 | 57 | +  )  | 
54 | 58 |   # Number of parallel runs of L-BFGS-B.  | 
55 | 59 |   random_restarts: int = struct.field(pytree_node=False, default=25)  | 
56 | 60 |   # Number of iterations for each L-BFGS-B run.  | 
@@ -144,10 +148,23 @@ def _opt_score_fn(x):  | 
144 | 148 |     def setup(rng):  | 
145 | 149 |       return jax.random.uniform(rng, shape=feature_shape)  | 
146 | 150 | 
 
  | 
147 |  | -    # Constraints are [0, 1].  | 
148 |  | -    constraints = sp.Constraint(  | 
149 |  | -        bounds=(np.zeros(feature_shape), np.ones(feature_shape))  | 
 | 151 | +    continuous_features_min, continuous_features_max = itertools.zip_longest(  | 
 | 152 | +        *self.continuous_features_bounds  | 
150 | 153 |     )  | 
 | 154 | +    bounds = []  | 
 | 155 | +    for bound, fill_value in (  | 
 | 156 | +        (continuous_features_min, 0.0),  | 
 | 157 | +        (continuous_features_max, 1.0),  | 
 | 158 | +    ):  | 
 | 159 | +      # Pad the bound with fill_value to the length of the feature dimension  | 
 | 160 | +      # with padding and account for the parallel dimension, so the shape of the  | 
 | 161 | +      # bound is the same as `feature_shape` above.  | 
 | 162 | +      padded_bound = bound + (fill_value,) * (  | 
 | 163 | +          self.n_feature_dimensions_with_padding.continuous - len(bound)  | 
 | 164 | +      )  | 
 | 165 | +      bounds.append(np.stack([padded_bound] * parallel_dim, axis=0))  | 
 | 166 | + | 
 | 167 | +    constraints = sp.Constraint(bounds=(bounds[0], bounds[1]))  | 
151 | 168 | 
 
  | 
152 | 169 |     new_features, _ = optimize(  | 
153 | 170 |         jax.vmap(setup)(jax.random.split(init_seed, self.random_restarts)),  | 
@@ -199,9 +216,15 @@ def __call__(  | 
199 | 216 |         empty_features.continuous.shape[-1],  | 
200 | 217 |         empty_features.categorical.shape[-1],  | 
201 | 218 |     )  | 
 | 219 | +    continous_features_bounds = [  | 
 | 220 | +        (float(spec.bounds[0]), float(spec.bounds[1]))  | 
 | 221 | +        for spec in converter.output_specs.continuous  | 
 | 222 | +    ]  | 
 | 223 | + | 
202 | 224 |     return LBFGSBOptimizer(  | 
203 | 225 |         n_feature_dimensions=n_feature_dimensions,  | 
204 | 226 |         n_feature_dimensions_with_padding=n_feature_dimensions_with_padding,  | 
 | 227 | +        continuous_features_bounds=continous_features_bounds,  | 
205 | 228 |         random_restarts=self.random_restarts,  | 
206 | 229 |         maxiter=self.maxiter,  | 
207 | 230 |     )  | 
0 commit comments