Skip to content

Commit

Permalink
Update BaseTrial mark_as method to mark failed_reason (#1889)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1889

Update BaseTrial mark_as to change failed_reason and update unit tests to test for failed_reason

Reviewed By: mpolson64

Differential Revision: D49853374

fbshipit-source-id: 7cfc1f8e2713f40345e56ba1475522dc0235b0e3
  • Loading branch information
eonofrey authored and facebook-github-bot committed Oct 3, 2023
1 parent e67a2e2 commit 994fc89
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def mark_as(
elif status == TrialStatus.ABANDONED:
self.mark_abandoned(reason=kwargs.get("reason"), unsafe=unsafe)
elif status == TrialStatus.FAILED:
self.mark_failed(unsafe=unsafe)
self.mark_failed(reason=kwargs.get("reason"), unsafe=unsafe)
elif status == TrialStatus.COMPLETED:
self.mark_completed(unsafe=unsafe)
elif status == TrialStatus.EARLY_STOPPED:
Expand Down
22 changes: 18 additions & 4 deletions ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.runner import Runner
from ax.exceptions.core import UserInputError
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.result import Ok
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand All @@ -23,7 +24,6 @@
get_test_map_data_experiment,
)


TEST_DATA = Data(
df=pd.DataFrame(
[
Expand Down Expand Up @@ -129,6 +129,16 @@ def test_abandonment(self) -> None:
self.assertFalse(self.trial.status.is_failed)
self.assertTrue(self.trial.did_not_complete)

def test_failed(self) -> None:
fail_reason = "testing"
self.trial.runner = SyntheticRunner()
self.trial.run()
self.trial.mark_failed(reason=fail_reason)

self.assertTrue(self.trial.status.is_failed)
self.assertTrue(self.trial.did_not_complete)
self.assertEqual(self.trial.failed_reason, fail_reason)

def test_mark_as(self) -> None:
for terminal_status in (
TrialStatus.ABANDONED,
Expand All @@ -143,13 +153,17 @@ def test_mark_as(self) -> None:
if status == TrialStatus.RUNNING:
kwargs["no_runner_required"] = True
if status == TrialStatus.ABANDONED:
kwargs["reason"] = "test_reason"
kwargs["reason"] = "test_reason_abandon"
if status == TrialStatus.FAILED:
kwargs["reason"] = "test_reason_failed"
self.trial.mark_as(status=status, **kwargs)
self.assertTrue(self.trial.status == status)

if status == TrialStatus.ABANDONED:
self.assertEqual(self.trial.abandoned_reason, "test_reason")
else:
self.assertEqual(self.trial.abandoned_reason, "test_reason_abandon")
self.assertIsNone(self.trial.failed_reason)
elif status == TrialStatus.FAILED:
self.assertEqual(self.trial.failed_reason, "test_reason_failed")
self.assertIsNone(self.trial.abandoned_reason)

if status != TrialStatus.RUNNING:
Expand Down

0 comments on commit 994fc89

Please sign in to comment.