diff --git a/core/src/main/java/ai/timefold/solver/core/enterprise/TimefoldSolverEnterpriseService.java b/core/src/main/java/ai/timefold/solver/core/enterprise/TimefoldSolverEnterpriseService.java index 22ebeca27b..1378722578 100644 --- a/core/src/main/java/ai/timefold/solver/core/enterprise/TimefoldSolverEnterpriseService.java +++ b/core/src/main/java/ai/timefold/solver/core/enterprise/TimefoldSolverEnterpriseService.java @@ -29,6 +29,7 @@ import ai.timefold.solver.core.impl.localsearch.decider.LocalSearchDecider; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.Acceptor; import ai.timefold.solver.core.impl.localsearch.decider.forager.LocalSearchForager; +import ai.timefold.solver.core.impl.localsearch.decider.restart.RestartStrategy; import ai.timefold.solver.core.impl.partitionedsearch.PartitionedSearchPhase; import ai.timefold.solver.core.impl.solver.termination.Termination; @@ -100,7 +101,8 @@ ConstructionHeuristicDecider buildConstructionHeuristic(T ConstructionHeuristicForager forager, HeuristicConfigPolicy configPolicy); LocalSearchDecider buildLocalSearch(int moveThreadCount, Termination termination, - MoveSelector moveSelector, Acceptor acceptor, LocalSearchForager forager, + MoveSelector moveSelector, RestartStrategy restartStrategy, + Acceptor acceptor, LocalSearchForager forager, EnvironmentMode environmentMode, HeuristicConfigPolicy configPolicy); PartitionedSearchPhase buildPartitionedSearch(int phaseIndex, diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhaseFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhaseFactory.java index 22bd5d9986..f0a98b2c38 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhaseFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhaseFactory.java @@ -32,6 +32,7 @@ import ai.timefold.solver.core.impl.localsearch.decider.acceptor.AcceptorFactory; import ai.timefold.solver.core.impl.localsearch.decider.forager.LocalSearchForager; import ai.timefold.solver.core.impl.localsearch.decider.forager.LocalSearchForagerFactory; +import ai.timefold.solver.core.impl.localsearch.decider.restart.RestoreBestSolutionRestartStrategy; import ai.timefold.solver.core.impl.phase.AbstractPhaseFactory; import ai.timefold.solver.core.impl.solver.recaller.BestSolutionRecaller; import ai.timefold.solver.core.impl.solver.termination.Termination; @@ -65,6 +66,7 @@ private LocalSearchDecider buildDecider(HeuristicConfigPolicy termination) { var moveSelector = buildMoveSelector(configPolicy); var acceptor = buildAcceptor(configPolicy); + var restartStrategy = new RestoreBestSolutionRestartStrategy(); var forager = buildForager(configPolicy); if (moveSelector.isNeverEnding() && !forager.supportsNeverEndingMoveSelector()) { throw new IllegalStateException("The moveSelector (" + moveSelector @@ -77,11 +79,12 @@ private LocalSearchDecider buildDecider(HeuristicConfigPolicy decider; if (moveThreadCount == null) { - decider = new LocalSearchDecider<>(configPolicy.getLogIndentation(), termination, moveSelector, acceptor, forager); + decider = new LocalSearchDecider<>(configPolicy.getLogIndentation(), termination, moveSelector, + restartStrategy, acceptor, forager); } else { decider = TimefoldSolverEnterpriseService.loadOrFail(TimefoldSolverEnterpriseService.Feature.MULTITHREADED_SOLVING) - .buildLocalSearch(moveThreadCount, termination, moveSelector, acceptor, forager, environmentMode, - configPolicy); + .buildLocalSearch(moveThreadCount, termination, moveSelector, restartStrategy, acceptor, forager, + environmentMode, configPolicy); } if (environmentMode.isNonIntrusiveFullAsserted()) { decider.setAssertMoveScoreFromScratch(true); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/LocalSearchDecider.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/LocalSearchDecider.java index bcdb810ec2..306556a797 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/LocalSearchDecider.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/LocalSearchDecider.java @@ -6,6 +6,7 @@ import ai.timefold.solver.core.impl.heuristic.selector.move.MoveSelector; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.Acceptor; import ai.timefold.solver.core.impl.localsearch.decider.forager.LocalSearchForager; +import ai.timefold.solver.core.impl.localsearch.decider.restart.RestartStrategy; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; @@ -29,6 +30,7 @@ public class LocalSearchDecider { protected final String logIndentation; protected final Termination termination; protected final MoveSelector moveSelector; + protected final RestartStrategy restartStrategy; protected final Acceptor acceptor; protected final LocalSearchForager forager; @@ -36,10 +38,12 @@ public class LocalSearchDecider { protected boolean assertExpectedUndoMoveScore = false; public LocalSearchDecider(String logIndentation, Termination termination, - MoveSelector moveSelector, Acceptor acceptor, LocalSearchForager forager) { + MoveSelector moveSelector, RestartStrategy restartStrategy, + Acceptor acceptor, LocalSearchForager forager) { this.logIndentation = logIndentation; this.termination = termination; this.moveSelector = moveSelector; + this.restartStrategy = restartStrategy; this.acceptor = acceptor; this.forager = forager; } @@ -73,18 +77,22 @@ public void setAssertExpectedUndoMoveScore(boolean assertExpectedUndoMoveScore) // ************************************************************************ public void solvingStarted(SolverScope solverScope) { + restartStrategy.solvingStarted(solverScope); moveSelector.solvingStarted(solverScope); acceptor.solvingStarted(solverScope); forager.solvingStarted(solverScope); } public void phaseStarted(LocalSearchPhaseScope phaseScope) { + phaseScope.setDecider(this); + restartStrategy.phaseStarted(phaseScope); moveSelector.phaseStarted(phaseScope); acceptor.phaseStarted(phaseScope); forager.phaseStarted(phaseScope); } public void stepStarted(LocalSearchStepScope stepScope) { + restartStrategy.stepStarted(stepScope); moveSelector.stepStarted(stepScope); acceptor.stepStarted(stepScope); forager.stepStarted(stepScope); @@ -148,23 +156,41 @@ protected void pickMove(LocalSearchStepScope stepScope) { } public void stepEnded(LocalSearchStepScope stepScope) { + if (restartStrategy.isSolverStuck(stepScope)) { + restartStrategy.applyRestart(stepScope); + } + restartStrategy.stepEnded(stepScope); moveSelector.stepEnded(stepScope); acceptor.stepEnded(stepScope); forager.stepEnded(stepScope); } public void phaseEnded(LocalSearchPhaseScope phaseScope) { + restartStrategy.phaseEnded(phaseScope); moveSelector.phaseEnded(phaseScope); acceptor.phaseEnded(phaseScope); forager.phaseEnded(phaseScope); } public void solvingEnded(SolverScope solverScope) { + restartStrategy.solvingEnded(solverScope); moveSelector.solvingEnded(solverScope); acceptor.solvingEnded(solverScope); forager.solvingEnded(solverScope); } + public void setWorkingSolutionFromBestSolution(LocalSearchStepScope stepScope) { + stepScope.getPhaseScope().getSolverScope().setWorkingSolutionFromBestSolution(); + // Adjust the step score to reflect the best score, + // ensuring the score of the last completed step is the current best one + stepScope.setScore(stepScope.getPhaseScope().getBestScore()); + // Changing the working solution requires reinitializing the move selector. + // The acceptor should not be restarted, as this may lead to an inconsistent state, + // such as changing the scores of all late elements in LA and DLAS. + // 1 - The move selector will reset all cached lists using old solution entity references + moveSelector.phaseStarted(stepScope.getPhaseScope()); + } + public void solvingError(SolverScope solverScope, Exception exception) { // Overridable by a subclass. } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactory.java index f1d2232668..2b3561840c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactory.java @@ -17,6 +17,8 @@ import ai.timefold.solver.core.impl.localsearch.decider.acceptor.lateacceptance.LateAcceptanceAcceptor; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.simulatedannealing.SimulatedAnnealingAcceptor; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stepcountinghillclimbing.StepCountingHillClimbingAcceptor; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.DiminishedReturnsStuckCriterion; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.tabu.EntityTabuAcceptor; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.tabu.MoveTabuAcceptor; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.tabu.ValueTabuAcceptor; @@ -215,11 +217,13 @@ private Optional> buildMoveTabuAcceptor(HeuristicCon return Optional.empty(); } - private Optional> buildLateAcceptanceAcceptor() { + private Optional> + buildLateAcceptanceAcceptor() { if (acceptorTypeListsContainsAcceptorType(AcceptorType.LATE_ACCEPTANCE) || (!acceptorTypeListsContainsAcceptorType(AcceptorType.DIVERSIFIED_LATE_ACCEPTANCE) && acceptorConfig.getLateAcceptanceSize() != null)) { - var acceptor = new LateAcceptanceAcceptor(); + StuckCriterion strategy = new DiminishedReturnsStuckCriterion<>(); + var acceptor = new LateAcceptanceAcceptor<>(strategy); acceptor.setLateAcceptanceSize(Objects.requireNonNullElse(acceptorConfig.getLateAcceptanceSize(), 400)); return Optional.of(acceptor); } @@ -230,7 +234,8 @@ private Optional> buildLateAcceptanceAcceptor( buildDiversifiedLateAcceptanceAcceptor(HeuristicConfigPolicy configPolicy) { if (acceptorTypeListsContainsAcceptorType(AcceptorType.DIVERSIFIED_LATE_ACCEPTANCE)) { configPolicy.ensurePreviewFeature(PreviewFeature.DIVERSIFIED_LATE_ACCEPTANCE); - var acceptor = new DiversifiedLateAcceptanceAcceptor(); + StuckCriterion strategy = new DiminishedReturnsStuckCriterion<>(); + var acceptor = new DiversifiedLateAcceptanceAcceptor<>(strategy); acceptor.setLateAcceptanceSize(Objects.requireNonNullElse(acceptorConfig.getLateAcceptanceSize(), 5)); return Optional.of(acceptor); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/RestartableAcceptor.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/RestartableAcceptor.java new file mode 100644 index 0000000000..d914566394 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/RestartableAcceptor.java @@ -0,0 +1,63 @@ +package ai.timefold.solver.core.impl.localsearch.decider.acceptor; + +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; + +/** + * Base class designed to analyze whether the solving process needs to be restarted. + * Additionally, it also calls a reconfiguration logic as a result of restarting the solving process. + */ +public abstract class RestartableAcceptor extends AbstractAcceptor { + + private final StuckCriterion stuckCriterion; + protected boolean restartTriggered; + + protected RestartableAcceptor(StuckCriterion stuckCriterion) { + this.stuckCriterion = stuckCriterion; + } + + @Override + public void solvingStarted(SolverScope solverScope) { + super.solvingStarted(solverScope); + stuckCriterion.solvingStarted(solverScope); + } + + @Override + public void phaseStarted(LocalSearchPhaseScope phaseScope) { + super.phaseStarted(phaseScope); + stuckCriterion.phaseStarted(phaseScope); + } + + @Override + public void phaseEnded(LocalSearchPhaseScope phaseScope) { + super.phaseEnded(phaseScope); + stuckCriterion.phaseEnded(phaseScope); + } + + @Override + public void stepStarted(LocalSearchStepScope stepScope) { + super.stepStarted(stepScope); + stuckCriterion.stepStarted(stepScope); + } + + @Override + public void stepEnded(LocalSearchStepScope stepScope) { + super.stepEnded(stepScope); + stuckCriterion.stepEnded(stepScope); + } + + @Override + public boolean isAccepted(LocalSearchMoveScope moveScope) { + if (stuckCriterion.isSolverStuck(moveScope)) { + moveScope.getStepScope().getPhaseScope().setSolverStuck(true); + restartTriggered = true; + return true; + } + return accept(moveScope); + } + + protected abstract boolean accept(LocalSearchMoveScope moveScope); +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptor.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptor.java index b87373414b..1b5f5a70c1 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptor.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptor.java @@ -3,11 +3,12 @@ import java.util.Arrays; import ai.timefold.solver.core.api.score.Score; -import ai.timefold.solver.core.impl.localsearch.decider.acceptor.AbstractAcceptor; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.RestartableAcceptor; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; -public class DiversifiedLateAcceptanceAcceptor extends AbstractAcceptor { +public class DiversifiedLateAcceptanceAcceptor extends RestartableAcceptor { // The worst score in the late elements list protected Score lateWorse; @@ -19,6 +20,10 @@ public class DiversifiedLateAcceptanceAcceptor extends AbstractAccept protected Score[] previousScores; protected int lateScoreIndex = -1; + public DiversifiedLateAcceptanceAcceptor(StuckCriterion stuckCriterionDetection) { + super(stuckCriterionDetection); + } + public void setLateAcceptanceSize(int lateAcceptanceSize) { this.lateAcceptanceSize = lateAcceptanceSize; } @@ -48,7 +53,7 @@ private void validate() { @Override @SuppressWarnings({ "rawtypes", "unchecked" }) - public boolean isAccepted(LocalSearchMoveScope moveScope) { + protected boolean accept(LocalSearchMoveScope moveScope) { // The acceptance and replacement strategies are based on the work: // Diversified Late Acceptance Search by M. Namazi, C. Sanderson, M. A. H. Newton, M. M. A. Polash, and A. Sattar var moveScore = moveScope.getScore(); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptor.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptor.java index 332a1e15a1..6433f3ca3e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptor.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptor.java @@ -3,12 +3,13 @@ import java.util.Arrays; import ai.timefold.solver.core.api.score.Score; -import ai.timefold.solver.core.impl.localsearch.decider.acceptor.AbstractAcceptor; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.RestartableAcceptor; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; -public class LateAcceptanceAcceptor extends AbstractAcceptor { +public class LateAcceptanceAcceptor extends RestartableAcceptor { protected int lateAcceptanceSize = -1; protected boolean hillClimbingEnabled = true; @@ -16,6 +17,10 @@ public class LateAcceptanceAcceptor extends AbstractAcceptor[] previousScores; protected int lateScoreIndex = -1; + public LateAcceptanceAcceptor(StuckCriterion stuckCriterionDetection) { + super(stuckCriterionDetection); + } + public void setLateAcceptanceSize(int lateAcceptanceSize) { this.lateAcceptanceSize = lateAcceptanceSize; } @@ -47,7 +52,7 @@ private void validate() { @Override @SuppressWarnings("unchecked") - public boolean isAccepted(LocalSearchMoveScope moveScope) { + public boolean accept(LocalSearchMoveScope moveScope) { var moveScore = moveScope.getScore(); var lateScore = previousScores[lateScoreIndex]; if (moveScore.compareTo(lateScore) >= 0) { @@ -63,8 +68,12 @@ public boolean isAccepted(LocalSearchMoveScope moveScope) { @Override public void stepEnded(LocalSearchStepScope stepScope) { super.stepEnded(stepScope); - previousScores[lateScoreIndex] = stepScope.getScore(); - lateScoreIndex = (lateScoreIndex + 1) % lateAcceptanceSize; + if (!restartTriggered) { + previousScores[lateScoreIndex] = stepScope.getScore(); + lateScoreIndex = (lateScoreIndex + 1) % lateAcceptanceSize; + } else { + restartTriggered = false; + } } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/AbstractGeometricStuckCriterion.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/AbstractGeometricStuckCriterion.java new file mode 100644 index 0000000000..179fec99fd --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/AbstractGeometricStuckCriterion.java @@ -0,0 +1,73 @@ +package ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion; + +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Restart strategy, which exponentially increases the metric that triggers the restart process. + * The first restart occurs after the {@code scalingFactor * GEOMETRIC_FACTOR^restartCount} metric. + * Following that, the metric increases exponentially: 1, 2, 3, 5, 7, 10, 14... + *

+ * The strategy is based on the work: Search in a Small World by Toby Walsh + * + * @param the solution type + */ +public abstract class AbstractGeometricStuckCriterion implements StuckCriterion { + protected static final Logger logger = LoggerFactory.getLogger(AbstractGeometricStuckCriterion.class); + private static final double GEOMETRIC_FACTOR = 1.4; // Value extracted from the cited paper + + private final double scalingFactor; + protected long nextRestart; + private double currentGeometricGrowFactor; + + protected AbstractGeometricStuckCriterion(double scalingFactor) { + this.scalingFactor = scalingFactor; + } + + @Override + public void phaseStarted(LocalSearchPhaseScope phaseScope) { + // Do nothing + } + + @Override + public void phaseEnded(LocalSearchPhaseScope phaseScope) { + // Do nothing + } + + @Override + public void solvingStarted(SolverScope solverScope) { + currentGeometricGrowFactor = 1; + nextRestart = calculateNextRestart(); + } + + @Override + public void solvingEnded(SolverScope solverScope) { + // Do nothing + } + + @Override + public boolean isSolverStuck(LocalSearchMoveScope moveScope) { + var triggered = evaluateCriterion(moveScope); + if (triggered) { + logger.info( + "Restart triggered with geometric factor ({}), scaling factor of ({}), nextRestart ({}), best score ({})", + currentGeometricGrowFactor, scalingFactor, nextRestart, + moveScope.getStepScope().getPhaseScope().getBestScore()); + currentGeometricGrowFactor = Math.ceil(currentGeometricGrowFactor * GEOMETRIC_FACTOR); + nextRestart = calculateNextRestart(); + return true; + } + return false; + } + + private long calculateNextRestart() { + return (long) Math.ceil(currentGeometricGrowFactor * scalingFactor); + } + + abstract boolean evaluateCriterion(LocalSearchMoveScope moveScope); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterion.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterion.java new file mode 100644 index 0000000000..18e7fa2b5b --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterion.java @@ -0,0 +1,89 @@ +package ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; +import ai.timefold.solver.core.impl.solver.termination.DiminishedReturnsTermination; + +public class DiminishedReturnsStuckCriterion> + extends AbstractGeometricStuckCriterion { + protected static final long TIME_WINDOW_MILLIS = 60_000; + private static final double MINIMAL_IMPROVEMENT = 0.0001; + + private DiminishedReturnsTermination diminishedReturnsCriterion; + + private boolean triggered; + private Score_ currentBestScore; + + public DiminishedReturnsStuckCriterion() { + this(new DiminishedReturnsTermination<>(TIME_WINDOW_MILLIS, MINIMAL_IMPROVEMENT)); + } + + protected DiminishedReturnsStuckCriterion(DiminishedReturnsTermination diminishedReturnsCriterion) { + super(TIME_WINDOW_MILLIS); + this.diminishedReturnsCriterion = diminishedReturnsCriterion; + } + + @Override + @SuppressWarnings("unchecked") + boolean evaluateCriterion(LocalSearchMoveScope moveScope) { + var bestScore = moveScope.getStepScope().getPhaseScope().getBestScore(); + if (moveScope.getScore().compareTo(bestScore) > 0) { + bestScore = moveScope.getScore(); + } + triggered = diminishedReturnsCriterion.isTerminated(System.nanoTime(), (Score_) bestScore); + return triggered; + } + + @Override + public void stepStarted(LocalSearchStepScope stepScope) { + currentBestScore = stepScope.getPhaseScope().getBestScore(); + if (triggered) { + // We need to recreate the termination criterion as the time window has changed + diminishedReturnsCriterion = new DiminishedReturnsTermination<>(nextRestart, MINIMAL_IMPROVEMENT); + diminishedReturnsCriterion.start(System.nanoTime(), stepScope.getPhaseScope().getBestScore()); + triggered = false; + } + } + + @Override + public void stepEnded(LocalSearchStepScope stepScope) { + diminishedReturnsCriterion.stepEnded(stepScope); + if (currentBestScore.compareTo(stepScope.getPhaseScope().getBestScore()) < 0 && nextRestart > TIME_WINDOW_MILLIS) { + // If the solution has been improved after a restart, + // we reset the criterion and restart the evaluation of the metric + super.solvingStarted(stepScope.getPhaseScope().getSolverScope()); + diminishedReturnsCriterion = new DiminishedReturnsTermination<>(nextRestart, MINIMAL_IMPROVEMENT); + diminishedReturnsCriterion.start(System.nanoTime(), stepScope.getPhaseScope().getBestScore()); + logger.info("Stuck criterion reset, next restart ({}), previous best score({}), new best score ({})", nextRestart, + currentBestScore, stepScope.getPhaseScope().getBestScore()); + } + } + + @Override + public void phaseStarted(LocalSearchPhaseScope phaseScope) { + super.phaseStarted(phaseScope); + diminishedReturnsCriterion.phaseStarted(phaseScope); + triggered = false; + } + + @Override + public void phaseEnded(LocalSearchPhaseScope phaseScope) { + super.phaseEnded(phaseScope); + diminishedReturnsCriterion.phaseEnded(phaseScope); + } + + @Override + public void solvingStarted(SolverScope solverScope) { + super.solvingStarted(solverScope); + diminishedReturnsCriterion.solvingStarted(solverScope); + } + + @Override + public void solvingEnded(SolverScope solverScope) { + super.solvingEnded(solverScope); + diminishedReturnsCriterion.solvingEnded(solverScope); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/StuckCriterion.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/StuckCriterion.java new file mode 100644 index 0000000000..0d06ac35d7 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/StuckCriterion.java @@ -0,0 +1,21 @@ +package ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion; + +import ai.timefold.solver.core.api.solver.Solver; +import ai.timefold.solver.core.impl.localsearch.event.LocalSearchPhaseLifecycleListener; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; + +/** + * Allow defining strategies that identify when the {@link Solver solver} is stuck. + * + * @param the solution type + */ +public interface StuckCriterion extends LocalSearchPhaseLifecycleListener { + + /** + * Main logic that applies a specific metric to determine if a solver is stuck in a local optimum. + * + * @param moveScope cannot be null + * @return + */ + boolean isSolverStuck(LocalSearchMoveScope moveScope); +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategy.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategy.java new file mode 100644 index 0000000000..135abfd8c8 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategy.java @@ -0,0 +1,39 @@ +package ai.timefold.solver.core.impl.localsearch.decider.restart; + +import ai.timefold.solver.core.impl.localsearch.decider.LocalSearchDecider; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; +import ai.timefold.solver.core.impl.phase.event.PhaseLifecycleListener; +import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope; + +/** + * Base contract for defining restart strategies. + * The restart process is initiated + * when the {@link LocalSearchDecider decider} identifies + * that the solver is {@link StuckCriterion stuck} + * and requires some logic to alter the current solving flow. + * + * @param the solution type + */ +public sealed interface RestartStrategy extends PhaseLifecycleListener + permits RestoreBestSolutionRestartStrategy { + + /** + * Restarts the solver to help it to get unstuck and discover new better solutions. + * + * @param stepScope cannot be null + */ + void applyRestart(AbstractStepScope stepScope); + + /** + * Evaluates whether the solver is stuck based on specific criteria + * and determines if reconfiguration logic needs to be applied. + * + * @param stepScope cannot be null + * + * @return true if the solver needs to be reconfigured; or false otherwise + */ + default boolean isSolverStuck(AbstractStepScope stepScope) { + return stepScope.getPhaseScope().isSolverStuck(); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestoreBestSolutionRestartStrategy.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestoreBestSolutionRestartStrategy.java new file mode 100644 index 0000000000..0f1b8c2751 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestoreBestSolutionRestartStrategy.java @@ -0,0 +1,58 @@ +package ai.timefold.solver.core.impl.localsearch.decider.restart; + +import java.util.Objects; + +import ai.timefold.solver.core.impl.localsearch.decider.LocalSearchDecider; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; +import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope; +import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class RestoreBestSolutionRestartStrategy implements RestartStrategy { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + private LocalSearchDecider decider; + + @Override + public void applyRestart(AbstractStepScope stepScope) { + var solverScope = stepScope.getPhaseScope().getSolverScope(); + logger.trace("Resetting working solution, score ({})", solverScope.getBestScore()); + decider.setWorkingSolutionFromBestSolution((LocalSearchStepScope) stepScope); + // Mark the solver as unstuck as the best solution is already restored + stepScope.getPhaseScope().setSolverStuck(false); + } + + @Override + public void stepStarted(AbstractStepScope stepScope) { + // Do nothing + } + + @Override + public void stepEnded(AbstractStepScope stepScope) { + // Do nothing + } + + @Override + public void phaseStarted(AbstractPhaseScope phaseScope) { + this.decider = Objects.requireNonNull(((LocalSearchPhaseScope) phaseScope).getDecider()); + } + + @Override + public void phaseEnded(AbstractPhaseScope phaseScope) { + // Do nothing + } + + @Override + public void solvingStarted(SolverScope solverScope) { + // Do nothing + } + + @Override + public void solvingEnded(SolverScope solverScope) { + // Do nothing + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/scope/LocalSearchPhaseScope.java b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/scope/LocalSearchPhaseScope.java index 6fa927ede5..0aecf364c5 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/localsearch/scope/LocalSearchPhaseScope.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/localsearch/scope/LocalSearchPhaseScope.java @@ -1,6 +1,7 @@ package ai.timefold.solver.core.impl.localsearch.scope; import ai.timefold.solver.core.api.domain.solution.PlanningSolution; +import ai.timefold.solver.core.impl.localsearch.decider.LocalSearchDecider; import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope; import ai.timefold.solver.core.impl.solver.scope.SolverScope; @@ -9,6 +10,7 @@ */ public final class LocalSearchPhaseScope extends AbstractPhaseScope { + private LocalSearchDecider decider; private LocalSearchStepScope lastCompletedStepScope; public LocalSearchPhaseScope(SolverScope solverScope, int phaseIndex) { @@ -26,6 +28,14 @@ public void setLastCompletedStepScope(LocalSearchStepScope lastComple this.lastCompletedStepScope = lastCompletedStepScope; } + public LocalSearchDecider getDecider() { + return decider; + } + + public void setDecider(LocalSearchDecider decider) { + this.decider = decider; + } + // ************************************************************************ // Calculated methods // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/impl/phase/scope/AbstractPhaseScope.java b/core/src/main/java/ai/timefold/solver/core/impl/phase/scope/AbstractPhaseScope.java index 02af1c1821..9823b962b2 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/phase/scope/AbstractPhaseScope.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/phase/scope/AbstractPhaseScope.java @@ -34,6 +34,7 @@ public abstract class AbstractPhaseScope { protected long childThreadsScoreCalculationCount = 0L; protected int bestSolutionStepIndex; + protected boolean solverStuck = false; /** * As defined by #AbstractPhaseScope(SolverScope, int, boolean) @@ -247,6 +248,14 @@ public int getNextStepIndex() { return getLastCompletedStepScope().getStepIndex() + 1; } + public boolean isSolverStuck() { + return this.solverStuck; + } + + public void setSolverStuck(boolean solverStuck) { + this.solverStuck = solverStuck; + } + @Override public String toString() { return getClass().getSimpleName() + "(" + phaseIndex + ")"; diff --git a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactoryTest.java b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactoryTest.java index 97663a54f0..3bcf659abf 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactoryTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/AcceptorFactoryTest.java @@ -1,12 +1,8 @@ package ai.timefold.solver.core.impl.localsearch.decider.acceptor; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import java.util.Arrays; import java.util.List; diff --git a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptorTest.java b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptorTest.java index a30690e1c3..523168feab 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptorTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/DiversifiedLateAcceptanceAcceptorTest.java @@ -1,9 +1,12 @@ package ai.timefold.solver.core.impl.localsearch.decider.acceptor.lateacceptance; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.AbstractAcceptorTest; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; import ai.timefold.solver.core.impl.solver.scope.SolverScope; @@ -14,7 +17,9 @@ class DiversifiedLateAcceptanceAcceptorTest extends AbstractAcceptorTest { @Test void acceptanceCriterion() { - var acceptor = new DiversifiedLateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new DiversifiedLateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(3); var solverScope = new SolverScope<>(); @@ -51,7 +56,9 @@ void acceptanceCriterion() { @Test void replacementCriterion() { - var acceptor = new DiversifiedLateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new DiversifiedLateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(3); var solverScope = new SolverScope<>(); @@ -142,4 +149,19 @@ void replacementCriterion() { acceptor.isAccepted(moveScope0); assertThat(acceptor.previousScores[0]).isEqualTo(SimpleScore.of(-2001)); } + + @Test + void triggerReconfiguration() { + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(true); + var acceptor = new DiversifiedLateAcceptanceAcceptor<>(restartStrategy); + acceptor.setLateAcceptanceSize(3); + var solverScope = new SolverScope<>(); + var phaseScope = new LocalSearchPhaseScope<>(solverScope, 0); + var stepScope0 = new LocalSearchStepScope<>(phaseScope); + var moveScope0 = buildMoveScope(stepScope0, -2000); + assertThat(acceptor.isAccepted(moveScope0)).isTrue(); + verify(restartStrategy, times(1)).isSolverStuck(any()); + assertThat(phaseScope.isSolverStuck()).isTrue(); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptorTest.java b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptorTest.java index dc245c652f..a0d04cb89a 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptorTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/lateacceptance/LateAcceptanceAcceptorTest.java @@ -2,9 +2,13 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; import ai.timefold.solver.core.impl.localsearch.decider.acceptor.AbstractAcceptorTest; +import ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.StuckCriterion; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; import ai.timefold.solver.core.impl.solver.scope.SolverScope; @@ -15,7 +19,9 @@ class LateAcceptanceAcceptorTest extends AbstractAcceptorTest { @Test void lateAcceptanceSize() { - var acceptor = new LateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new LateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(3); acceptor.setHillClimbingEnabled(false); @@ -128,7 +134,9 @@ void lateAcceptanceSize() { @Test void hillClimbingEnabled() { - var acceptor = new LateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new LateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(2); acceptor.setHillClimbingEnabled(true); @@ -241,15 +249,33 @@ void hillClimbingEnabled() { @Test void zeroLateAcceptanceSize() { - var acceptor = new LateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new LateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(0); assertThatIllegalArgumentException().isThrownBy(() -> acceptor.phaseStarted(null)); } @Test void negativeLateAcceptanceSize() { - var acceptor = new LateAcceptanceAcceptor<>(); + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(false); + var acceptor = new LateAcceptanceAcceptor<>(restartStrategy); acceptor.setLateAcceptanceSize(-1); assertThatIllegalArgumentException().isThrownBy(() -> acceptor.phaseStarted(null)); } + + @Test + void triggerReconfiguration() { + var restartStrategy = mock(StuckCriterion.class); + when(restartStrategy.isSolverStuck(any())).thenReturn(true); + var acceptor = new LateAcceptanceAcceptor<>(restartStrategy); + acceptor.setLateAcceptanceSize(3); + var solverScope = new SolverScope<>(); + var phaseScope = new LocalSearchPhaseScope<>(solverScope, 0); + var stepScope0 = new LocalSearchStepScope<>(phaseScope); + var moveScope0 = buildMoveScope(stepScope0, -2000); + assertThat(acceptor.isAccepted(moveScope0)).isTrue(); + assertThat(phaseScope.isSolverStuck()).isTrue(); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterionTest.java b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterionTest.java new file mode 100644 index 0000000000..1ba0c0f594 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/acceptor/stuckcriterion/DiminishedReturnsStuckCriterionTest.java @@ -0,0 +1,79 @@ +package ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion; + +import static ai.timefold.solver.core.impl.localsearch.decider.acceptor.stuckcriterion.DiminishedReturnsStuckCriterion.TIME_WINDOW_MILLIS; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchMoveScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; +import ai.timefold.solver.core.impl.solver.termination.DiminishedReturnsTermination; + +import org.junit.jupiter.api.Test; + +class DiminishedReturnsStuckCriterionTest { + + @Test + void isSolverStuck() { + var solverScope = mock(SolverScope.class); + var phaseScope = mock(LocalSearchPhaseScope.class); + var stepScope = mock(LocalSearchStepScope.class); + var moveScope = mock(LocalSearchMoveScope.class); + var termination = mock(DiminishedReturnsTermination.class); + + when(moveScope.getStepScope()).thenReturn(stepScope); + when(stepScope.getPhaseScope()).thenReturn(phaseScope); + when(phaseScope.getSolverScope()).thenReturn(solverScope); + when(moveScope.getScore()).thenReturn(SimpleScore.of(1)); + when(phaseScope.getBestScore()).thenReturn(SimpleScore.of(1)); + when(termination.isTerminated(anyLong(), any())).thenReturn(false, true); + + // No restart + var strategy = new DiminishedReturnsStuckCriterion<>(termination); + strategy.solvingStarted(null); + strategy.phaseStarted(phaseScope); + assertThat(strategy.isSolverStuck(moveScope)).isFalse(); + + // First restart + assertThat(strategy.isSolverStuck(moveScope)).isTrue(); + assertThat(strategy.nextRestart).isEqualTo(2L * TIME_WINDOW_MILLIS); + + // Second restart + assertThat(strategy.isSolverStuck(moveScope)).isTrue(); + assertThat(strategy.nextRestart).isEqualTo(3L * TIME_WINDOW_MILLIS); + } + + @Test + void reset() { + var solverScope = mock(SolverScope.class); + var phaseScope = mock(LocalSearchPhaseScope.class); + var stepScope = mock(LocalSearchStepScope.class); + var moveScope = mock(LocalSearchMoveScope.class); + var termination = mock(DiminishedReturnsTermination.class); + + when(moveScope.getStepScope()).thenReturn(stepScope); + when(stepScope.getPhaseScope()).thenReturn(phaseScope); + when(phaseScope.getSolverScope()).thenReturn(solverScope); + when(moveScope.getScore()).thenReturn(SimpleScore.of(1)); + when(phaseScope.getBestScore()).thenReturn(SimpleScore.of(1)); + when(termination.isTerminated(anyLong(), any())).thenReturn(true); + + // Restart + var strategy = new DiminishedReturnsStuckCriterion<>(termination); + strategy.solvingStarted(null); + strategy.phaseStarted(phaseScope); + assertThat(strategy.isSolverStuck(moveScope)).isTrue(); + assertThat(strategy.nextRestart).isEqualTo(2L * TIME_WINDOW_MILLIS); + + // Reset + strategy.stepStarted(stepScope); + when(phaseScope.getBestScore()).thenReturn(SimpleScore.of(2)); + strategy.stepEnded(stepScope); + assertThat(strategy.nextRestart).isEqualTo(TIME_WINDOW_MILLIS); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategyTest.java b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategyTest.java new file mode 100644 index 0000000000..1614155a31 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/localsearch/decider/restart/RestartStrategyTest.java @@ -0,0 +1,37 @@ +package ai.timefold.solver.core.impl.localsearch.decider.restart; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; + +import ai.timefold.solver.core.impl.localsearch.decider.LocalSearchDecider; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchPhaseScope; +import ai.timefold.solver.core.impl.localsearch.scope.LocalSearchStepScope; +import ai.timefold.solver.core.impl.solver.scope.SolverScope; + +import org.junit.jupiter.api.Test; + +class RestartStrategyTest { + + @Test + void restoreBestSolution() { + // Requires the decider + var badStrategy = new RestoreBestSolutionRestartStrategy<>(); + assertThatThrownBy(() -> badStrategy.phaseStarted(null)).isInstanceOf(NullPointerException.class); + + // Restore the best solution + var strategy = new RestoreBestSolutionRestartStrategy<>(); + + var decider = mock(LocalSearchDecider.class); + var solverScope = mock(SolverScope.class); + var phaseScope = mock(LocalSearchPhaseScope.class); + when(phaseScope.getSolverScope()).thenReturn(solverScope); + when(phaseScope.getDecider()).thenReturn(decider); + var stepScope = mock(LocalSearchStepScope.class); + when(stepScope.getPhaseScope()).thenReturn(phaseScope); + strategy.solvingStarted(solverScope); + strategy.phaseStarted(phaseScope); + strategy.applyRestart(stepScope); + // Restore the best solution + verify(decider, times(1)).setWorkingSolutionFromBestSolution(any()); + } +}