forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transition criteria to GenerationNode (facebook#1887)
Summary: X-link: facebookresearch/aepsych#320 In this diff we do a few things: (1) Create a TransitionCriterion class: - this class will subsume the CompletionCriterion class and will be a bit more flexible - it has the same child classes + maximumtrialsinstatus subclass, we may add more subclasses or fields later as we further test (2) Create an list of transitioncriterion from the generationstep class: - minimum_trials_observed can be taken care of by the more flexible MinimumTrialsInStatus class - num_trials and enforce_num_trials can be taken care of by the more flexible MaximumTrialsInStatus class (3) adds a doc string to GenNode class - tangential but easy (4) updates the type of completion_criteria of GenerationStep from CompletionCriterion to TransitionCriterion In following diffs we will: (1) add transition criterion to the repr string + some of the other fields that havent made it yet (2) begin moving the functions related to completing the step up to node and leveraging the transition criterion for checks instead of indexes -- this is where we may need to add additional fields to transitioncriterion (3) add doc strings to everywhere in teh GenNode class (4) add additional unit tests to MaxTrials to bring coverage to 100% (5) skip max trial criterion addition if numtrials == -1 (6) clean up compeletion_criterion class once new ax release can be pinned to aepsych version Reviewed By: lena-kashtelyan Differential Revision: D49509997
- Loading branch information
1 parent
994fc89
commit 6fea95d
Showing
10 changed files
with
433 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from unittest.mock import patch | ||
|
||
import pandas as pd | ||
from ax.core.base_trial import TrialStatus | ||
from ax.core.data import Data | ||
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy | ||
from ax.modelbridge.registry import Models | ||
from ax.modelbridge.transition_criterion import ( | ||
MaxTrials, | ||
MinimumPreferenceOccurances, | ||
MinimumTrialsInStatus, | ||
) | ||
from ax.utils.common.testutils import TestCase | ||
from ax.utils.testing.core_stubs import get_experiment | ||
|
||
|
||
class TestTransitionCriterion(TestCase): | ||
def test_minimum_preference_criterion(self) -> None: | ||
"""Tests the minimum preference criterion subcalss of TransitionCriterion.""" | ||
criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3) | ||
experiment = get_experiment() | ||
generation_strategy = GenerationStrategy( | ||
name="SOBOL+GPEI::default", | ||
steps=[ | ||
GenerationStep( | ||
model=Models.SOBOL, | ||
num_trials=-1, | ||
completion_criteria=[criterion], | ||
), | ||
GenerationStep( | ||
model=Models.GPEI, | ||
num_trials=-1, | ||
max_parallelism=1, | ||
), | ||
], | ||
) | ||
generation_strategy.experiment = experiment | ||
|
||
# Has not seen enough of each preference | ||
self.assertFalse( | ||
generation_strategy._maybe_move_to_next_step( | ||
raise_data_required_error=False | ||
) | ||
) | ||
|
||
data = Data( | ||
df=pd.DataFrame( | ||
{ | ||
"trial_index": range(6), | ||
"arm_name": [f"{i}_0" for i in range(6)], | ||
"metric_name": ["m1" for _ in range(6)], | ||
"mean": [0, 0, 0, 1, 1, 1], | ||
"sem": [0 for _ in range(6)], | ||
} | ||
) | ||
) | ||
with patch.object(experiment, "fetch_data", return_value=data): | ||
# We have seen three "yes" and three "no" | ||
self.assertTrue( | ||
generation_strategy._maybe_move_to_next_step( | ||
raise_data_required_error=False | ||
) | ||
) | ||
self.assertEqual(generation_strategy._curr.model, Models.GPEI) | ||
|
||
def test_default_step_criterion_setup(self) -> None: | ||
"""This test ensures that the default completion criterion for GenerationSteps | ||
is set as expected. | ||
The default completion criterion is to create two TransitionCriterion, one | ||
of type `MaximumTrialsInStatus` and one of type `MinimumTrialsInStatus`. | ||
These are constructed via the inputs of `num_trials`, `enforce_num_trials`, | ||
and `minimum_trials_observed` on the GenerationStep. | ||
""" | ||
experiment = get_experiment() | ||
gs = GenerationStrategy( | ||
name="SOBOL+GPEI::default", | ||
steps=[ | ||
GenerationStep( | ||
model=Models.SOBOL, | ||
num_trials=3, | ||
enforce_num_trials=False, | ||
), | ||
GenerationStep( | ||
model=Models.GPEI, | ||
num_trials=4, | ||
max_parallelism=1, | ||
min_trials_observed=2, | ||
), | ||
GenerationStep( | ||
model=Models.GPEI, | ||
num_trials=-1, | ||
max_parallelism=1, | ||
), | ||
], | ||
) | ||
gs.experiment = experiment | ||
|
||
step_0_expected_transition_criteria = [ | ||
MaxTrials(threshold=3, enforce=False), | ||
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), | ||
] | ||
step_1_expected_transition_criteria = [ | ||
MaxTrials(threshold=4, enforce=True), | ||
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2), | ||
] | ||
step_2_expected_transition_criteria = [ | ||
MaxTrials(threshold=-1, enforce=True), | ||
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), | ||
] | ||
self.assertEqual( | ||
gs._steps[0].transition_criteria, step_0_expected_transition_criteria | ||
) | ||
self.assertEqual( | ||
gs._steps[1].transition_criteria, step_1_expected_transition_criteria | ||
) | ||
self.assertEqual( | ||
gs._steps[2].transition_criteria, step_2_expected_transition_criteria | ||
) | ||
|
||
# Check default results for `is_met` call | ||
self.assertTrue(gs._steps[0].transition_criteria[0].is_met(experiment)) | ||
self.assertTrue(gs._steps[0].transition_criteria[1].is_met(experiment)) | ||
self.assertFalse(gs._steps[1].transition_criteria[0].is_met(experiment)) | ||
self.assertFalse(gs._steps[1].transition_criteria[1].is_met(experiment)) | ||
|
||
def test_max_trials_status_arg(self) -> None: | ||
"""Tests the `only_in_status` argument checks the threshold based on the | ||
number of trials in specified status instead of all trials (which is the | ||
default behavior). | ||
""" | ||
experiment = get_experiment() | ||
criterion = MaxTrials( | ||
threshold=5, only_in_status=TrialStatus.RUNNING, enforce=True | ||
) | ||
self.assertFalse(criterion.is_met(experiment)) |
Oops, something went wrong.