From d21c9fcd28201f41dc388a35ef0bbed0a7639ee1 Mon Sep 17 00:00:00 2001 From: robogast Date: Tue, 1 Mar 2022 10:43:32 +0100 Subject: [PATCH] Revert "cleaned up some exception handling logic" This reverts commit ea93e4727385277cb71499831425591fd00b42b6. --- .../hydra_optuna_sweeper/_impl.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index 4cb89b717e4..ac05835b43c 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -214,25 +214,27 @@ def sweep(self, arguments: List[str]) -> None: values: Optional[List[float]] = None state: optuna.trial.TrialState = optuna.trial.TrialState.COMPLETE try: - try: - values = [float(v) for v in ( - [ret.return_values] if len(directions) == 1 else ret.return_values - )] - except (ValueError, TypeError): - raise ValueError( - "Return value(s) must be float-castable," - " and a sequence if multiple objectives are used." - f" Got '{ret.return_value}'." - ).with_traceback(sys.exc_info()[2]) - - if len(values) != len(directions): - raise ValueError( - "The number of the values and the number of the objectives are" - f" mismatched. Expect {len(directions)}, but actually {len(values)}." - ) - + if len(directions) == 1: + try: + values = [float(ret.return_value)] + except (ValueError, TypeError): + raise ValueError( + f"Return value must be float-castable. Got '{ret.return_value}'." + ).with_traceback(sys.exc_info()[2]) + else: + try: + values = [float(v) for v in ret.return_value] + except (ValueError, TypeError): + raise ValueError( + "Return value must be a list or tuple of float-castable values." + f" Got '{ret.return_value}'." + ).with_traceback(sys.exc_info()[2]) + if len(values) != len(directions): + raise ValueError( + "The number of the values and the number of the objectives are" + f" mismatched. Expect {len(directions)}, but actually {len(values)}." + ) study.tell(trial=trial, state=state, values=values) - except Exception as e: state = optuna.trial.TrialState.FAIL study.tell(trial=trial, state=state, values=values)