Skip to content

Commit

Permalink
Update MinimumTrialsInStatus to accept a list of statuses to check (#…
Browse files Browse the repository at this point in the history
…1932)

Summary:
Pull Request resolved: #1932

This updates the MinimumTrialsInStatus criterion class to accept a list of trials to check and a list of trials not to check which enables `only_in` allowing it to be a more flexible class

Things in the pipeline:
(0) Update maxtrials to accept a list of statuses to check and a list of statuses to not check
(1) Use the transition criterion to determine if a node is complete
(2) add is_complete to generationNode and then use that in generation Strategy for moving forward
(3) When transition_criterion list is empty, unlimited trials can be generated + skip max trial criterion addition if numtrials == -1
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed

Reviewed By: lena-kashtelyan

Differential Revision: D50608777

fbshipit-source-id: 9ba4a4919efc93d254540ffa0e196c71b1b7970a
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 30, 2023
1 parent 80eb384 commit cc89030
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def __post_init__(self) -> None:
)
transition_criteria.append(
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=self.min_trials_observed,
)
)
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_single_criterion(self) -> None:
def test_many_criteria(self) -> None:
criteria = [
MinimumPreferenceOccurances(metric_name="m1", threshold=3),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=5),
MinimumTrialsInStatus(statuses=[TrialStatus.COMPLETED], threshold=5),
]

experiment = get_experiment()
Expand Down
29 changes: 22 additions & 7 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,23 @@ def test_default_step_criterion_setup(self) -> None:
step_0_expected_transition_criteria = [
MaxTrials(threshold=3, enforce=True, transition_to="GenerationStep_1"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=0,
transition_to="GenerationStep_1",
),
]
step_1_expected_transition_criteria = [
MaxTrials(threshold=4, enforce=False, transition_to="GenerationStep_2"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=2,
transition_to="GenerationStep_2",
),
]
step_2_expected_transition_criteria = [
MaxTrials(threshold=-1, enforce=True, transition_to="GenerationStep_3"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=0,
transition_to="GenerationStep_3",
),
Expand Down Expand Up @@ -192,6 +192,18 @@ def test_minimum_trials_in_status_is_met(self) -> None:
.is_met(experiment, gs._steps[0].trials_from_node)
)

# Check mixed status MinimumTrialsInStatus
min_criterion = MinimumTrialsInStatus(
threshold=3, statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED]
)
self.assertFalse(
min_criterion.is_met(experiment, gs._steps[0].trials_from_node)
)
for idx, trial in experiment.trials.items():
if idx == 2:
trial._status = TrialStatus.EARLY_STOPPED
self.assertTrue(min_criterion.is_met(experiment, gs._steps[0].trials_from_node))

def test_max_trials_is_met(self) -> None:
"""Test that the is_met method in MaxTrials works"""
experiment = get_branin_experiment()
Expand Down Expand Up @@ -320,7 +332,9 @@ def test_trials_from_node_none(self) -> None:
)

# Check MinimumTrialsInStatus
min_criterion = MinimumTrialsInStatus(threshold=3, status=TrialStatus.COMPLETED)
min_criterion = MinimumTrialsInStatus(
threshold=3, statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED]
)
with self.assertLogs(TransitionCriterion.__module__, logging.WARNING) as logger:
self.assertFalse(min_criterion.is_met(experiment, trials_from_node=None))
self.assertTrue(
Expand Down Expand Up @@ -348,14 +362,15 @@ def test_repr(self) -> None:
)

minimum_trials_in_status_criterion = MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=0,
transition_to="GenerationStep_2",
)
self.assertEqual(
str(minimum_trials_in_status_criterion),
"MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0,"
+ " transition_to='GenerationStep_2')",
"MinimumTrialsInStatus(statuses=[<TrialStatus.COMPLETED: 3>, "
+ "<TrialStatus.EARLY_STOPPED: 7>], threshold=0, transition_to="
+ "'GenerationStep_2')",
)

minimum_preference_occurances_criterion = MinimumPreferenceOccurances(
Expand Down
27 changes: 15 additions & 12 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from abc import abstractmethod
from logging import Logger
from typing import Optional, Set
from typing import List, Optional, Set

from ax.core.base_trial import TrialStatus

Expand Down Expand Up @@ -58,9 +58,12 @@ class MinimumTrialsInStatus(TransitionCriterion):
# TODO: @mgarrard rename to MinTrials and expand functionality to mirror
# `MaxTrials` after legacy usecases are updated.
def __init__(
self, status: TrialStatus, threshold: int, transition_to: Optional[str] = None
self,
statuses: List[TrialStatus],
threshold: int,
transition_to: Optional[str] = None,
) -> None:
self.status = status
self.statuses = statuses
self.threshold = threshold
super().__init__(transition_to=transition_to)

Expand All @@ -73,27 +76,27 @@ def is_met(
trials_from_node: A set containing the indices of trials that were
generated from this GenerationNode.
"""
exp_trials_with_statuses = set()
for status in self.statuses:
exp_trials_with_statuses = exp_trials_with_statuses.union(
experiment.trial_indices_by_status[status]
)

# Trials from node should not be none for any new GenerationStrategies
if trials_from_node is None:
logger.warning(
"trials_from_node is None, will check threshold on experiment level.",
)
return (
len(experiment.trial_indices_by_status[self.status]) >= self.threshold
)
return len(exp_trials_with_statuses) >= self.threshold
return (
len(
trials_from_node.intersection(
experiment.trial_indices_by_status[self.status]
)
)
len(trials_from_node.intersection(exp_trials_with_statuses))
>= self.threshold
)

def __repr__(self) -> str:
"""Returns a string representation of MinimumTrialsInStatus."""
return (
f"{self.__class__.__name__}(status={self.status}, "
f"{self.__class__.__name__}(statuses={self.statuses}, "
f"threshold={self.threshold}, transition_to='{self.transition_to}')"
)

Expand Down
2 changes: 1 addition & 1 deletion ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def transition_criteria_from_json(
if criterion_type == "MinimumTrialsInStatus":
criterion_list.append(
MinimumTrialsInStatus(
status=object_from_json(criterion_json.pop("status")),
statuses=object_from_json(criterion_json.pop("statuses")),
threshold=criterion_json.pop("threshold"),
)
)
Expand Down

0 comments on commit cc89030

Please sign in to comment.