diff --git a/src/tranquilo/acceptance_decision.py b/src/tranquilo/acceptance_decision.py index 6d62e8a..2384958 100644 --- a/src/tranquilo/acceptance_decision.py +++ b/src/tranquilo/acceptance_decision.py @@ -65,22 +65,29 @@ def accept_greedy( wrapped_criterion({candidate_index: 1}) candidate_fval = np.mean(history.get_fvals(candidate_index)) - actual_improvement = -(candidate_fval - state.fval) + candidate_improvement = -(candidate_fval - state.fval) rho = calculate_rho( - actual_improvement=actual_improvement, + actual_improvement=candidate_improvement, expected_improvement=subproblem_solution.expected_improvement, ) best_x, best_fval, best_index = history.get_best() - - if best_fval < candidate_fval: + + assert np.isfinite(best_fval) + assert isinstance(best_x, np.ndarray) + assert isinstance(best_index, int) + assert isinstance(best_fval, float) + assert best_x.ndim == 1 + assert np.mean(history.get_fvals(best_index)) == best_fval + + if best_fval < candidate_fval and best_fval < state.fval: candidate_x = best_x candidate_fval = best_fval candidate_index = best_index - overall_improvement = -(candidate_fval - state.fval) + overall_improvement = -(best_fval - state.fval) else: - overall_improvement = actual_improvement + overall_improvement = candidate_improvement is_accepted = overall_improvement >= min_improvement diff --git a/src/tranquilo/history.py b/src/tranquilo/history.py index 96841a5..39f6340 100644 --- a/src/tranquilo/history.py +++ b/src/tranquilo/history.py @@ -187,7 +187,7 @@ def get_best(self): """ fvals = self.get_fvals(np.arange(self.n_xs)) average_fvals = {key: np.mean(val) for key, val in fvals.items()} - index = pd.Series(average_fvals).idxmin() + index = int(pd.Series(average_fvals).idxmin()) return self.get_xs(index), average_fvals[index], index def get_n_evals(self, x_indices): diff --git a/tests/test_history.py b/tests/test_history.py index 13177a6..ab9a548 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -233,5 +233,6 @@ def test_get_model_data_with_repeated_evaluations(noisy_history, average): def test_get_best(noisy_history): x, fval, index = noisy_history.get_best() assert index == 0 + assert isinstance(index, int) assert fval == 142.5 aaae(x, np.array([0, 1, 2]))