2323from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
2424from botorch .exceptions import InputDataError , UnsupportedError
2525from botorch .exceptions .warnings import OptimizationWarning
26- from botorch .generation .gen import gen_candidates_scipy
26+ from botorch .generation .gen import gen_candidates_scipy , TGenCandidates
2727from botorch .logging import logger
2828from botorch .optim .initializers import (
2929 gen_batch_initial_conditions ,
@@ -64,6 +64,7 @@ def optimize_acqf(
6464 post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
6565 batch_initial_conditions : Optional [Tensor ] = None ,
6666 return_best_only : bool = True ,
67+ gen_candidates : TGenCandidates = gen_candidates_scipy ,
6768 sequential : bool = False ,
6869 ** kwargs : Any ,
6970) -> Tuple [Tensor , Tensor ]:
@@ -103,6 +104,8 @@ def optimize_acqf(
103104 this if you do not want to use default initialization strategy.
104105 return_best_only: If False, outputs the solutions corresponding to all
105106 random restart initializations of the optimization.
107+ gen_candidates: A callable for generating candidates given initial
108+ conditions. Default: `gen_candidates_scipy`
106109 sequential: If False, uses joint optimization, otherwise uses sequential
107110 optimization.
108111 kwargs: Additonal keyword arguments.
@@ -273,7 +276,7 @@ def _optimize_batch_candidates(
273276 if timeout_sec is not None :
274277 timeout_sec = (timeout_sec - start_time ) / len (batched_ics )
275278
276- scipy_kws = {
279+ gen_kws = {
277280 "acquisition_function" : acq_function ,
278281 "lower_bounds" : None if bounds [0 ].isinf ().all () else bounds [0 ],
279282 "upper_bounds" : None if bounds [1 ].isinf ().all () else bounds [1 ],
@@ -289,8 +292,8 @@ def _optimize_batch_candidates(
289292 # optimize using random restart optimization
290293 with warnings .catch_warnings (record = True ) as ws :
291294 warnings .simplefilter ("always" , category = OptimizationWarning )
292- batch_candidates_curr , batch_acq_values_curr = gen_candidates_scipy (
293- initial_conditions = batched_ics_ , ** scipy_kws
295+ batch_candidates_curr , batch_acq_values_curr = gen_candidates (
296+ initial_conditions = batched_ics_ , ** gen_kwargs
294297 )
295298 opt_warnings += ws
296299 batch_candidates_list .append (batch_candidates_curr )
0 commit comments