Skip to content

Commit

Permalink
Add method get_best to history class
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2024
1 parent 4a3335e commit 5674335
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 0 deletions.
103 changes: 103 additions & 0 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_acceptance_decider(
"naive_noisy": accept_naive_noisy,
"noisy": accept_noisy,
"classic_line_search": accept_classic_line_search,
"greedy": accept_greedy,
}

out = get_component(
Expand All @@ -38,6 +39,38 @@ def get_acceptance_decider(
return out


def accept_greedy(
subproblem_solution,
state,
history,
*,
wrapped_criterion,
min_improvement,
):
"""Do a greedy acceptance step for a trustregion algorithm.
Args:
subproblem_solution (SubproblemResult): Result of the subproblem solution.
state (State): Namedtuple containing the trustregion, criterion value of
previously accepted point, indices of model points, etc.
wrapped_criterion (callable): The criterion function.
min_improvement (float): Minimum improvement required to accept a point.
Returns:
AcceptanceResult
"""
out = _accept_greedy(

Check warning on line 63 in src/tranquilo/acceptance_decision.py

View check run for this annotation

Codecov / codecov/patch

src/tranquilo/acceptance_decision.py#L63

Added line #L63 was not covered by tests
subproblem_solution=subproblem_solution,
state=state,
history=history,
wrapped_criterion=wrapped_criterion,
min_improvement=min_improvement,
n_evals=1,
)
return out

Check warning on line 71 in src/tranquilo/acceptance_decision.py

View check run for this annotation

Codecov / codecov/patch

src/tranquilo/acceptance_decision.py#L71

Added line #L71 was not covered by tests


def _accept_classic(
subproblem_solution,
state,
Expand Down Expand Up @@ -239,6 +272,76 @@ def accept_classic_line_search(
return res


def _accept_greedy(
subproblem_solution,
state,
history,
*,
wrapped_criterion,
min_improvement,
n_evals,
):
"""Do a simple greedy acceptance step for a trustregion algorithm.
Args:
subproblem_solution (SubproblemResult): Result of the subproblem solution.
state (State): Namedtuple containing the trustregion, criterion value of
previously accepted point, indices of model points, etc.
wrapped_criterion (callable): The criterion function.
min_improvement (float): Minimum improvement required to accept a point.
Returns:
AcceptanceResult
"""
candidate_x = subproblem_solution.x
candidate_index = history.add_xs(candidate_x)
wrapped_criterion({candidate_index: n_evals})

candidate_fval = np.mean(history.get_fvals(candidate_index))
actual_improvement = -(candidate_fval - state.fval)

rho = calculate_rho(
actual_improvement=actual_improvement,
expected_improvement=subproblem_solution.expected_improvement,
)

best_x, best_fval, best_index = history.get_best()

if best_fval < candidate_fval:
candidate_x = best_x
candidate_fval = best_fval
candidate_index = best_index
overall_improvement = -(candidate_fval - state.fval)

Check warning on line 315 in src/tranquilo/acceptance_decision.py

View check run for this annotation

Codecov / codecov/patch

src/tranquilo/acceptance_decision.py#L312-L315

Added lines #L312 - L315 were not covered by tests
else:
overall_improvement = actual_improvement

is_accepted = overall_improvement >= min_improvement

if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_evals,
)
else:
res = _get_acceptance_result(

Check warning on line 332 in src/tranquilo/acceptance_decision.py

View check run for this annotation

Codecov / codecov/patch

src/tranquilo/acceptance_decision.py#L332

Added line #L332 was not covered by tests
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=n_evals,
)

return res


def _accept_simple(
subproblem_solution,
state,
Expand Down
37 changes: 37 additions & 0 deletions tests/test_acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from tranquilo.sample_points import get_sampler
from tranquilo.acceptance_decision import (
_accept_greedy,
_accept_simple,
_get_acceptance_result,
calculate_rho,
Expand Down Expand Up @@ -82,6 +83,42 @@ def wrapped_criterion(eval_info):
assert_array_equal(res_got.candidate_x, 1.0 + np.arange(2))


# ======================================================================================
# Test accept_greedy
# ======================================================================================


@pytest.mark.parametrize("state", states)
def test_accept_greedy(
state,
subproblem_solution,
):
history = History(functype="scalar")

idxs = history.add_xs(np.arange(10).reshape(5, 2))

history.add_evals(idxs.repeat(2), np.arange(10))

def wrapped_criterion(eval_info):
indices = np.array(list(eval_info)).repeat(np.array(list(eval_info.values())))
history.add_evals(indices, -indices)

res_got = _accept_greedy(
subproblem_solution=subproblem_solution,
state=state,
history=history,
wrapped_criterion=wrapped_criterion,
min_improvement=0.0,
n_evals=2,
)

assert res_got.accepted
assert res_got.index == 5
assert res_got.candidate_index == 5
assert_array_equal(res_got.x, subproblem_solution.x)
assert_array_equal(res_got.candidate_x, 1.0 + np.arange(2))


# ======================================================================================
# Test _get_acceptance_result
# ======================================================================================
Expand Down

0 comments on commit 5674335

Please sign in to comment.