From cc89030ddefa0f27369148f61d90cb62f5ce56f8 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Mon, 30 Oct 2023 13:28:06 -0700 Subject: [PATCH] Update MinimumTrialsInStatus to accept a list of statuses to check (#1932) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1932 This updates the MinimumTrialsInStatus criterion class to accept a list of trials to check and a list of trials not to check which enables `only_in` allowing it to be a more flexible class Things in the pipeline: (0) Update maxtrials to accept a list of statuses to check and a list of statuses to not check (1) Use the transition criterion to determine if a node is complete (2) add is_complete to generationNode and then use that in generation Strategy for moving forward (3) When transition_criterion list is empty, unlimited trials can be generated + skip max trial criterion addition if numtrials == -1 (4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode (5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed Reviewed By: lena-kashtelyan Differential Revision: D50608777 fbshipit-source-id: 9ba4a4919efc93d254540ffa0e196c71b1b7970a --- ax/modelbridge/generation_node.py | 2 +- .../tests/test_completion_criterion.py | 2 +- .../tests/test_transition_criterion.py | 29 ++++++++++++++----- ax/modelbridge/transition_criterion.py | 27 +++++++++-------- ax/storage/json_store/decoder.py | 2 +- 5 files changed, 40 insertions(+), 22 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 775db2657b9..32c30edcbb8 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -521,7 +521,7 @@ def __post_init__(self) -> None: ) transition_criteria.append( MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, + statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=self.min_trials_observed, ) ) diff --git a/ax/modelbridge/tests/test_completion_criterion.py b/ax/modelbridge/tests/test_completion_criterion.py index 4af57c64573..e3d71a86032 100644 --- a/ax/modelbridge/tests/test_completion_criterion.py +++ b/ax/modelbridge/tests/test_completion_criterion.py @@ -76,7 +76,7 @@ def test_single_criterion(self) -> None: def test_many_criteria(self) -> None: criteria = [ MinimumPreferenceOccurances(metric_name="m1", threshold=3), - MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=5), + MinimumTrialsInStatus(statuses=[TrialStatus.COMPLETED], threshold=5), ] experiment = get_experiment() diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 3f6555e9f0d..26572966fdb 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -110,7 +110,7 @@ def test_default_step_criterion_setup(self) -> None: step_0_expected_transition_criteria = [ MaxTrials(threshold=3, enforce=True, transition_to="GenerationStep_1"), MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, + statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=0, transition_to="GenerationStep_1", ), @@ -118,7 +118,7 @@ def test_default_step_criterion_setup(self) -> None: step_1_expected_transition_criteria = [ MaxTrials(threshold=4, enforce=False, transition_to="GenerationStep_2"), MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, + statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=2, transition_to="GenerationStep_2", ), @@ -126,7 +126,7 @@ def test_default_step_criterion_setup(self) -> None: step_2_expected_transition_criteria = [ MaxTrials(threshold=-1, enforce=True, transition_to="GenerationStep_3"), MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, + statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=0, transition_to="GenerationStep_3", ), @@ -192,6 +192,18 @@ def test_minimum_trials_in_status_is_met(self) -> None: .is_met(experiment, gs._steps[0].trials_from_node) ) + # Check mixed status MinimumTrialsInStatus + min_criterion = MinimumTrialsInStatus( + threshold=3, statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED] + ) + self.assertFalse( + min_criterion.is_met(experiment, gs._steps[0].trials_from_node) + ) + for idx, trial in experiment.trials.items(): + if idx == 2: + trial._status = TrialStatus.EARLY_STOPPED + self.assertTrue(min_criterion.is_met(experiment, gs._steps[0].trials_from_node)) + def test_max_trials_is_met(self) -> None: """Test that the is_met method in MaxTrials works""" experiment = get_branin_experiment() @@ -320,7 +332,9 @@ def test_trials_from_node_none(self) -> None: ) # Check MinimumTrialsInStatus - min_criterion = MinimumTrialsInStatus(threshold=3, status=TrialStatus.COMPLETED) + min_criterion = MinimumTrialsInStatus( + threshold=3, statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED] + ) with self.assertLogs(TransitionCriterion.__module__, logging.WARNING) as logger: self.assertFalse(min_criterion.is_met(experiment, trials_from_node=None)) self.assertTrue( @@ -348,14 +362,15 @@ def test_repr(self) -> None: ) minimum_trials_in_status_criterion = MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, + statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED], threshold=0, transition_to="GenerationStep_2", ) self.assertEqual( str(minimum_trials_in_status_criterion), - "MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0," - + " transition_to='GenerationStep_2')", + "MinimumTrialsInStatus(statuses=[, " + + "], threshold=0, transition_to=" + + "'GenerationStep_2')", ) minimum_preference_occurances_criterion = MinimumPreferenceOccurances( diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index 2502ef0f043..dceeac4552d 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -5,7 +5,7 @@ from abc import abstractmethod from logging import Logger -from typing import Optional, Set +from typing import List, Optional, Set from ax.core.base_trial import TrialStatus @@ -58,9 +58,12 @@ class MinimumTrialsInStatus(TransitionCriterion): # TODO: @mgarrard rename to MinTrials and expand functionality to mirror # `MaxTrials` after legacy usecases are updated. def __init__( - self, status: TrialStatus, threshold: int, transition_to: Optional[str] = None + self, + statuses: List[TrialStatus], + threshold: int, + transition_to: Optional[str] = None, ) -> None: - self.status = status + self.statuses = statuses self.threshold = threshold super().__init__(transition_to=transition_to) @@ -73,27 +76,27 @@ def is_met( trials_from_node: A set containing the indices of trials that were generated from this GenerationNode. """ + exp_trials_with_statuses = set() + for status in self.statuses: + exp_trials_with_statuses = exp_trials_with_statuses.union( + experiment.trial_indices_by_status[status] + ) + # Trials from node should not be none for any new GenerationStrategies if trials_from_node is None: logger.warning( "trials_from_node is None, will check threshold on experiment level.", ) - return ( - len(experiment.trial_indices_by_status[self.status]) >= self.threshold - ) + return len(exp_trials_with_statuses) >= self.threshold return ( - len( - trials_from_node.intersection( - experiment.trial_indices_by_status[self.status] - ) - ) + len(trials_from_node.intersection(exp_trials_with_statuses)) >= self.threshold ) def __repr__(self) -> str: """Returns a string representation of MinimumTrialsInStatus.""" return ( - f"{self.__class__.__name__}(status={self.status}, " + f"{self.__class__.__name__}(statuses={self.statuses}, " f"threshold={self.threshold}, transition_to='{self.transition_to}')" ) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 1de3b72ed28..f05b18bc43f 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -355,7 +355,7 @@ def transition_criteria_from_json( if criterion_type == "MinimumTrialsInStatus": criterion_list.append( MinimumTrialsInStatus( - status=object_from_json(criterion_json.pop("status")), + statuses=object_from_json(criterion_json.pop("statuses")), threshold=criterion_json.pop("threshold"), ) )