@@ -24,6 +24,7 @@
from botorch.acquisition import analytic, monte_carlo, multi_objective
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.knowledge_gradient import (
    _get_value_function,
    qKnowledgeGradient,
@@ -468,6 +469,89 @@ def gen_batch_initial_conditions(
    return batch_initial_conditions


+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
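+    r"""Generate initial conditions by sampling around sampled optima.
+
+    A `frac_random` fraction of the `raw_samples` raw candidates is drawn
+    from the feasible region, and the remainder is obtained by perturbing
+    the acquisition function's sampled optima (`optimal_inputs`). The
+    `num_restarts` initial conditions are then Boltzmann-sampled from the
+    raw candidates according to their acquisition values.
+
+    Example (assuming `qJES` is a `qJointEntropySearch` instance, which
+    exposes the required `optimal_inputs` attribute):
+        >>> ics = gen_optimal_input_initial_conditions(
+        >>>     acq_function=qJES,
+        >>>     bounds=bounds,
+        >>>     q=1,
+        >>>     num_restarts=10,
+        >>>     raw_samples=512,
+        >>>     options={"frac_random": 0.25},
+        >>> )
+    """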
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    # `options` defaults to None; normalize it before the `.get` calls below.
+    options = options or {}
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
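+    # Draw the purely random fraction of the raw candidates from the feasible
+    # polytope defined by `bounds` and any linear constraints.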
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
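+    # Obtain the remaining candidates by Gaussian perturbation of the sampled
+    # optima, so that most restarts start near promising inputs.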
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
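+    # Optionally also sample around the best observed points; with no
+    # `best_X` passed, `sample_points_around_best` infers them from the
+    # model's predictions at the baseline points.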
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X` in `batch_limit`-sized
+        # chunks to bound peak memory usage.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+
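+    # Boltzmann sampling: standardize the acquisition values and weight each
+    # candidate by exp(eta * z-score), so a larger `eta` concentrates the
+    # restarts on the highest-valued candidates.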
+    eta = options.get("eta", 2.0)
+    weights = torch.exp(eta * standardize(acq_vals))
+    idx = torch.multinomial(weights, num_restarts, replacement=True)
+
+    return X[idx]
+
+
def gen_one_shot_kg_initial_conditions(
    acq_function: qKnowledgeGradient,
    bounds: Tensor,
@@ -1141,6 +1225,7 @@ def sample_points_around_best(
    best_pct: float = 5.0,
    subset_sigma: float = 1e-1,
    prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
) -> Tensor | None:
    r"""Find best points and sample nearby points.

@@ -1154,65 +1239,71 @@ def sample_points_around_best(
        subset_sigma: The standard deviation of the additive gaussian
            noise for perturbing a subset of dimensions of the best points.
        prob_perturb: The probability of perturbing each dimension.
+        best_X: A provided set of best points to sample around. If None,
+            the set is instead inferred from the model's predictions at
+            the baseline points. Used, e.g., for info-theoretic acquisition
+            functions, where the sampled optima serve as suggestions for
+            acquisition function optimization.

    Returns:
        An optional `n_discrete_points x d`-dim tensor containing the
        sampled points. This is None if no baseline points are found.
    """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
            return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find Pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
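+    # In high dimension (d >= 20, or when `prob_perturb` is set), half of the
+    # points use truncated-normal perturbations of all dimensions and the
+    # rest perturb only a random subset of dimensions.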
    use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
    n_trunc_normal_points = (
        n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
@@ -1234,7 +1325,7 @@ def sample_points_around_best(
    )
    perturbed_X = torch.cat([perturbed_X, perturbed_subset_dims_X], dim=0)
    # shuffle points
-    perm = torch.randperm(perturbed_X.shape[0], device=X.device)
+    perm = torch.randperm(perturbed_X.shape[0], device=best_X.device)
    perturbed_X = perturbed_X[perm]
    return perturbed_X
