diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 23335fd17ef..c53ee9c260b 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -765,7 +765,7 @@ def mark_as( elif status == TrialStatus.ABANDONED: self.mark_abandoned(reason=kwargs.get("reason"), unsafe=unsafe) elif status == TrialStatus.FAILED: - self.mark_failed(unsafe=unsafe) + self.mark_failed(reason=kwargs.get("reason"), unsafe=unsafe) elif status == TrialStatus.COMPLETED: self.mark_completed(unsafe=unsafe) elif status == TrialStatus.EARLY_STOPPED: diff --git a/ax/core/tests/test_trial.py b/ax/core/tests/test_trial.py index f3f95e17bee..87839502817 100644 --- a/ax/core/tests/test_trial.py +++ b/ax/core/tests/test_trial.py @@ -13,6 +13,7 @@ from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.runner import Runner from ax.exceptions.core import UserInputError +from ax.runners.synthetic import SyntheticRunner from ax.utils.common.result import Ok from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -23,7 +24,6 @@ get_test_map_data_experiment, ) - TEST_DATA = Data( df=pd.DataFrame( [ @@ -129,6 +129,16 @@ def test_abandonment(self) -> None: self.assertFalse(self.trial.status.is_failed) self.assertTrue(self.trial.did_not_complete) + def test_failed(self) -> None: + fail_reason = "testing" + self.trial.runner = SyntheticRunner() + self.trial.run() + self.trial.mark_failed(reason=fail_reason) + + self.assertTrue(self.trial.status.is_failed) + self.assertTrue(self.trial.did_not_complete) + self.assertEqual(self.trial.failed_reason, fail_reason) + def test_mark_as(self) -> None: for terminal_status in ( TrialStatus.ABANDONED, @@ -143,13 +153,17 @@ def test_mark_as(self) -> None: if status == TrialStatus.RUNNING: kwargs["no_runner_required"] = True if status == TrialStatus.ABANDONED: - kwargs["reason"] = "test_reason" + kwargs["reason"] = "test_reason_abandon" + if status == TrialStatus.FAILED: + kwargs["reason"] = "test_reason_failed" self.trial.mark_as(status=status, **kwargs) self.assertTrue(self.trial.status == status) if status == TrialStatus.ABANDONED: - self.assertEqual(self.trial.abandoned_reason, "test_reason") - else: + self.assertEqual(self.trial.abandoned_reason, "test_reason_abandon") + self.assertIsNone(self.trial.failed_reason) + elif status == TrialStatus.FAILED: + self.assertEqual(self.trial.failed_reason, "test_reason_failed") self.assertIsNone(self.trial.abandoned_reason) if status != TrialStatus.RUNNING: