2323from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
2424from botorch .exceptions import InputDataError , UnsupportedError
2525from botorch .exceptions .warnings import OptimizationWarning
26- from botorch .generation .gen import gen_candidates_scipy
26+ from botorch .generation .gen import gen_candidates_scipy , TGenCandidates
2727from botorch .logging import logger
2828from botorch .optim .initializers import (
2929 gen_batch_initial_conditions ,
3030 gen_one_shot_kg_initial_conditions ,
3131)
3232from botorch .optim .stopping import ExpMAStoppingCriterion
33+ from botorch .optim .utils import _filter_kwargs
3334from torch import Tensor
3435
3536INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
6465 post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
6566 batch_initial_conditions : Optional [Tensor ] = None ,
6667 return_best_only : bool = True ,
68+ gen_candidates : Optional [TGenCandidates ] = None ,
6769 sequential : bool = False ,
6870 ** kwargs : Any ,
6971) -> Tuple [Tensor , Tensor ]:
@@ -103,6 +105,12 @@ def optimize_acqf(
103105 this if you do not want to use default initialization strategy.
104106 return_best_only: If False, outputs the solutions corresponding to all
105107 random restart initializations of the optimization.
108+ gen_candidates: A callable for generating candidates (and their associated
109+ acquisition values) given a tensor of initial conditions and an
110+ acquisition function. Other common inputs include lower and upper bounds
111+ and a dictionary of options, but refer to the documentation of specific
112+ generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
113+ for method-specific inputs. Default: `gen_candidates_scipy`
106114 sequential: If False, uses joint optimization, otherwise uses sequential
107115 optimization.
108116 kwargs: Additonal keyword arguments.
@@ -134,6 +142,9 @@ def optimize_acqf(
134142 """
135143 start_time : float = time .monotonic ()
136144 timeout_sec = kwargs .pop ("timeout_sec" , None )
145+ # using a default of None simplifies unit testing
146+ if gen_candidates is None :
147+ gen_candidates = gen_candidates_scipy
137148
138149 if inequality_constraints is None :
139150 if not (bounds .ndim == 2 and bounds .shape [0 ] == 2 ):
@@ -229,6 +240,7 @@ def optimize_acqf(
229240 sequential = False ,
230241 ic_generator = ic_gen ,
231242 timeout_sec = timeout_sec ,
243+ gen_candidates = gen_candidates ,
232244 )
233245
234246 candidate_list .append (candidate )
@@ -277,6 +289,11 @@ def optimize_acqf(
277289 batch_limit : int = options .get (
278290 "batch_limit" , num_restarts if not nonlinear_inequality_constraints else 1
279291 )
292+ has_parameter_constraints = (
293+ inequality_constraints is not None
294+ or equality_constraints is not None
295+ or nonlinear_inequality_constraints is not None
296+ )
280297
281298 def _optimize_batch_candidates (
282299 timeout_sec : Optional [float ],
@@ -288,24 +305,36 @@ def _optimize_batch_candidates(
288305 if timeout_sec is not None :
289306 timeout_sec = (timeout_sec - start_time ) / len (batched_ics )
290307
291- scipy_kws = {
308+ gen_kwargs = {
292309 "acquisition_function" : acq_function ,
293310 "lower_bounds" : None if bounds [0 ].isinf ().all () else bounds [0 ],
294311 "upper_bounds" : None if bounds [1 ].isinf ().all () else bounds [1 ],
295312 "options" : {k : v for k , v in options .items () if k not in INIT_OPTION_KEYS },
296- "inequality_constraints" : inequality_constraints ,
297- "equality_constraints" : equality_constraints ,
298- "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
299313 "fixed_features" : fixed_features ,
300314 "timeout_sec" : timeout_sec ,
301315 }
302316
317+ if has_parameter_constraints :
318+ # only add parameter constraints to gen_kwargs if they are specified
319+ # to avoid unnecessary warnings in _filter_kwargs
320+ gen_kwargs .update (
321+ {
322+ "inequality_constraints" : inequality_constraints ,
323+ "equality_constraints" : equality_constraints ,
324+ # the line is too long
325+ "nonlinear_inequality_constraints" : (
326+ nonlinear_inequality_constraints
327+ ),
328+ }
329+ )
330+ filtered_gen_kwargs = _filter_kwargs (gen_candidates , ** gen_kwargs )
331+
303332 for i , batched_ics_ in enumerate (batched_ics ):
304333 # optimize using random restart optimization
305334 with warnings .catch_warnings (record = True ) as ws :
306335 warnings .simplefilter ("always" , category = OptimizationWarning )
307- batch_candidates_curr , batch_acq_values_curr = gen_candidates_scipy (
308- initial_conditions = batched_ics_ , ** scipy_kws
336+ batch_candidates_curr , batch_acq_values_curr = gen_candidates (
337+ initial_conditions = batched_ics_ , ** filtered_gen_kwargs
309338 )
310339 opt_warnings += ws
311340 batch_candidates_list .append (batch_candidates_curr )
0 commit comments