Skip to content

Commit

Permalink
Improve error handling (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Jan 16, 2024
1 parent 7042550 commit a94c8b5
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 143 deletions.
90 changes: 62 additions & 28 deletions src/tranquilo/acceptance_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,29 @@ def accept_classic_line_search(
is_accepted = overall_improvement >= min_improvement

# ==================================================================================
# Return results
# Process and return results

if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=1,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=1,
)

res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=1,
)
return res


Expand Down Expand Up @@ -262,15 +274,26 @@ def _accept_simple(

is_accepted = actual_improvement >= min_improvement

res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_evals,
)
if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_evals,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=n_evals,
)

return res

Expand Down Expand Up @@ -321,15 +344,26 @@ def accept_noisy(

is_accepted = actual_improvement >= min_improvement

res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_2,
)
if np.isfinite(candidate_fval):
res = _get_acceptance_result(
candidate_x=candidate_x,
candidate_fval=candidate_fval,
candidate_index=candidate_index,
rho=rho,
is_accepted=is_accepted,
old_state=state,
n_evals=n_2,
)
else:
res = _get_acceptance_result(
candidate_x=state.x,
candidate_fval=state.fval,
candidate_index=state.index,
rho=-np.inf,
is_accepted=False,
old_state=state,
n_evals=n_2,
)

return res

Expand Down
2 changes: 2 additions & 0 deletions src/tranquilo/process_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def process_arguments(
model_fitter_options=None,
cube_subsolver="bntr_fast",
sphere_subsolver="gqtpar_fast",
retry_subproblem_with_fallback=True,
subsolver_options=None,
acceptance_decider=None,
acceptance_decider_options=None,
Expand Down Expand Up @@ -189,6 +190,7 @@ def process_arguments(
solve_subproblem = get_subsolver(
cube_solver=cube_subsolver,
sphere_solver=sphere_subsolver,
retry_with_fallback=retry_subproblem_with_fallback,
user_options=subsolver_options,
)

Expand Down
82 changes: 65 additions & 17 deletions src/tranquilo/solve_subproblem.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
gqtpar,
)
from tranquilo.subsolvers.gqtpar_fast import gqtpar_fast
from tranquilo.subsolvers.wrapped_subsolvers import (
slsqp_sphere,
solve_multistart,
)
from tranquilo.options import SubsolverOptions
from tranquilo.subsolvers.fallback_subsolvers import (
robust_cube_solver,
robust_sphere_solver_inscribed_cube,
robust_sphere_solver_norm_constraint,
robust_sphere_solver_reparametrized,
robust_cube_solver_multistart,
)


def get_subsolver(sphere_solver, cube_solver, user_options=None):
def get_subsolver(sphere_solver, cube_solver, retry_with_fallback, user_options=None):
"""Get an algorithm-function with partialled options.
Args:
Expand All @@ -36,6 +39,8 @@ def get_subsolver(sphere_solver, cube_solver, user_options=None):
to be ``lower_bounds`` and ``upper_bounds``. The fourth argument needs to be
``x_candidate``, an initial guess for the solution in the unit space.
Moreover, subsolvers can have any number of additional keyword arguments.
retry_with_fallback (bool): Whether to retry solving the subproblem with a
fallback solver if the optimized subsolver raises an exception.
user_options (dict):
Options for the subproblem solver. The following are supported:
- maxiter (int): Maximum number of iterations to perform when solving the
Expand Down Expand Up @@ -70,13 +75,16 @@ def get_subsolver(sphere_solver, cube_solver, user_options=None):
built_in_sphere_solvers = {
"gqtpar": gqtpar,
"gqtpar_fast": gqtpar_fast,
"slsqp_sphere": slsqp_sphere,
"fallback_reparametrized": robust_sphere_solver_reparametrized,
"fallback_inscribed_cube": robust_sphere_solver_inscribed_cube,
"fallback_norm_constraint": robust_sphere_solver_norm_constraint,
}

built_in_cube_solvers = {
"bntr": bntr,
"bntr_fast": bntr_fast,
"multistart": solve_multistart,
"fallback_cube": robust_cube_solver,
"fallback_multistart": robust_cube_solver_multistart,
}

_sphere_subsolver = get_component(
Expand All @@ -97,10 +105,29 @@ def get_subsolver(sphere_solver, cube_solver, user_options=None):
mandatory_signature=["model", "x_candidate", "lower_bounds", "upper_bounds"],
)

_fallback_sphere_solver = get_component(
name_or_func="fallback_inscribed_cube",
component_name="fallback_sphere_solver",
func_dict=built_in_sphere_solvers,
default_options=SubsolverOptions(),
mandatory_signature=["model", "x_candidate"],
)

_fallback_cube_solver = get_component(
name_or_func="fallback_cube",
component_name="fallback_cube_solver",
func_dict=built_in_cube_solvers,
default_options=SubsolverOptions(),
mandatory_signature=["model", "x_candidate"],
)

solver = partial(
_solve_subproblem_template,
sphere_solver=_sphere_subsolver,
cube_solver=_cube_subsolver,
fallback_sphere_solver=_fallback_sphere_solver,
fallback_cube_solver=_fallback_cube_solver,
retry_with_fallback=retry_with_fallback,
)

return solver
Expand All @@ -111,6 +138,9 @@ def _solve_subproblem_template(
trustregion,
sphere_solver,
cube_solver,
fallback_sphere_solver,
fallback_cube_solver,
retry_with_fallback,
):
"""Solve the quadratic subproblem.
Expand All @@ -128,6 +158,8 @@ def _solve_subproblem_template(
``upper_bounds``. The fourth argument needs to be ``x_candidate``, an
initial guess for the solution in the unit space. Moreover, subsolvers can
have any number of additional keyword arguments.
retry_with_fallback (bool): Whether to retry solving the subproblem with a
fallback solver if the optimized subsolver raises an exception.
Returns:
Expand All @@ -145,16 +177,32 @@ def _solve_subproblem_template(
"""
old_x_unit = trustregion.map_to_unit(trustregion.center)

solver = sphere_solver if trustregion.shape == "sphere" else cube_solver

raw_result = solver(
model=model,
x_candidate=old_x_unit,
# bounds can be passed to both solvers because the functions returned by
# `get_component` ignore redundant arguments.
lower_bounds=-np.ones_like(old_x_unit),
upper_bounds=np.ones_like(old_x_unit),
)
if trustregion.shape == "sphere":
solver = sphere_solver
fallback_solver = fallback_sphere_solver
else:
solver = cube_solver
fallback_solver = fallback_cube_solver

# try finding a solution using the optimized subsolver. If this raises an
# exception, use a fallback solver if requested, otherwise raise the exception.
try:
raw_result = solver(
model=model,
x_candidate=old_x_unit,
# bounds can be passed to both solvers because the functions returned by
# `get_component` ignore redundant arguments.
lower_bounds=-np.ones_like(old_x_unit),
upper_bounds=np.ones_like(old_x_unit),
)
except Exception as e:
if not retry_with_fallback:
raise e

raw_result = fallback_solver(
model=model,
x_candidate=old_x_unit,
)

if trustregion.shape == "cube":
raw_result["x"] = np.clip(raw_result["x"], -1.0, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion src/tranquilo/subsolvers/bntr_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _take_preliminary_gradient_descent_step_and_check_for_solution(

if not converged:
trustregion_radius = min(
max(min_radius, max(trustregion_radius, radius_lower_bound)), max_radius
max(min_radius, trustregion_radius, radius_lower_bound), max_radius
)

return (
Expand Down
Loading

0 comments on commit a94c8b5

Please sign in to comment.