Skip to content

Commit

Permalink
Update MaxTrials to accept list of statuses to check and avoid (#1933)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1933

This updates the MaxTrials criterion class to accept a list of trials to check and a list of trials not to check which enables `only_in` and `not_in` functionality for MaxTrials allowing it to be a more flexible class.

Things in the pipeline:
(0) fix json decoding of `transition_to` argument
(1) Use the transition criterion to determine if a node is complete
(2) add is_complete to generationNode and then use that in generation Strategy for moving forward
(3) When transition_criterion list is empty, unlimited trials can be generated + skip max trial criterion addition if numtrials == -1
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed

Reviewed By: lena-kashtelyan

Differential Revision: D50622912

fbshipit-source-id: d0c7c5f189659e4d7dc1f1b3382773bb0bf3fcbd
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 30, 2023
1 parent cc89030 commit b62cbe7
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 50 deletions.
73 changes: 54 additions & 19 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,43 @@ def test_max_trials_is_met(self) -> None:
)
)

# Check not in statuses and only in statuses
max_criterion_not_in_statuses = MaxTrials(
threshold=2, enforce=True, not_in_statuses=[TrialStatus.COMPLETED]
)
max_criterion_only_statuses = MaxTrials(
threshold=2,
enforce=True,
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
)
# experiment currently has 4 trials, but none of them are completed
self.assertTrue(
max_criterion_not_in_statuses.is_met(
experiment, trials_from_node=gs._steps[0].trials_from_node
)
)
self.assertFalse(
max_criterion_only_statuses.is_met(
experiment, trials_from_node=gs._steps[0].trials_from_node
)
)

# set 3 of the 4 trials to status == completed
for _idx, trial in experiment.trials.items():
trial._status = TrialStatus.COMPLETED
if _idx == 2:
break
self.assertTrue(
max_criterion_only_statuses.is_met(
experiment, trials_from_node=gs._steps[0].trials_from_node
)
)
self.assertFalse(
max_criterion_not_in_statuses.is_met(
experiment, trials_from_node=gs._steps[0].trials_from_node
)
)

# if num_trials == -1, should always pass
self.assertTrue(
gs._steps[2]
Expand All @@ -271,17 +308,6 @@ def test_max_trials_is_met(self) -> None:
)
)

def test_max_trials_status_arg(self) -> None:
"""Tests the `only_in_status` argument checks the threshold based on the
number of trials in specified status instead of all trials (which is the
default behavior).
"""
experiment = get_experiment()
criterion = MaxTrials(
threshold=5, only_in_status=TrialStatus.RUNNING, enforce=True
)
self.assertFalse(criterion.is_met(experiment, trials_from_node={2, 3}))

def test_trials_from_node_none(self) -> None:
"""Tests MinimumTrialsInStatus and MaxTrials default to experiment
level trials when trials_from_node is None.
Expand All @@ -301,18 +327,24 @@ def test_trials_from_node_none(self) -> None:
],
)
max_criterion_with_status = MaxTrials(
threshold=2, enforce=True, only_in_status=TrialStatus.COMPLETED
threshold=2, enforce=True, only_in_statuses=[TrialStatus.COMPLETED]
)
max_criterion = MaxTrials(threshold=2, enforce=True)
warning_msg = (
"trials_from_node is None, will check threshold on experiment level"
warning_msg_max = (
"`trials_from_node` is None, will check threshold"
+ " on experiment level for MaxTrials."
)

warning_msg_min = (
"`trials_from_node` is None, will check threshold on"
+ " experiment level for MinimumTrialsInStatus."
)

# no trials so criterion should be false, then add trials to pass criterion
with self.assertLogs(TransitionCriterion.__module__, logging.WARNING) as logger:
self.assertFalse(max_criterion.is_met(experiment, trials_from_node=None))
self.assertTrue(
any(warning_msg in output for output in logger.output),
any(warning_msg_max in output for output in logger.output),
logger.output,
)
for _i in range(3):
Expand All @@ -338,7 +370,7 @@ def test_trials_from_node_none(self) -> None:
with self.assertLogs(TransitionCriterion.__module__, logging.WARNING) as logger:
self.assertFalse(min_criterion.is_met(experiment, trials_from_node=None))
self.assertTrue(
any(warning_msg in output for output in logger.output),
any(warning_msg_min in output for output in logger.output),
logger.output,
)
for _idx, trial in experiment.trials.items():
Expand All @@ -349,16 +381,19 @@ def test_repr(self) -> None:
"""Tests that the repr string is correctly formatted for all
TransitionCriterion child classes.
"""
self.maxDiff = None
max_trials_criterion = MaxTrials(
threshold=5,
enforce=True,
transition_to="GenerationStep_1",
only_in_status=TrialStatus.COMPLETED,
only_in_statuses=[TrialStatus.COMPLETED],
not_in_statuses=[TrialStatus.FAILED],
)
self.assertEqual(
str(max_trials_criterion),
"MaxTrials(threshold=5, enforce=True, only_in_status="
+ "TrialStatus.COMPLETED, transition_to='GenerationStep_1')",
"MaxTrials(threshold=5, enforce=True, only_in_statuses="
+ "[<TrialStatus.COMPLETED: 3>], transition_to='GenerationStep_1',"
+ " not_in_statuses=[<TrialStatus.FAILED: 2>])",
)

minimum_trials_in_status_criterion = MinimumTrialsInStatus(
Expand Down
76 changes: 48 additions & 28 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ def is_met(
) -> bool:
pass

def experiment_trials_by_status(
self, experiment: Experiment, statuses: List[TrialStatus]
) -> Set[int]:
"""Get the trial indices from the experiment with the desired statuses.
Args:
experiment: The experiment associated with this GenerationStrategy.
statuses: The statuses to filter on.
Returns:
The trial indices in the experiment with the desired statuses.
"""
exp_trials_with_statuses = set()
for status in statuses:
exp_trials_with_statuses = exp_trials_with_statuses.union(
experiment.trial_indices_by_status[status]
)
return exp_trials_with_statuses


class MinimumTrialsInStatus(TransitionCriterion):
"""
Expand Down Expand Up @@ -76,16 +94,15 @@ def is_met(
trials_from_node: A set containing the indices of trials that were
generated from this GenerationNode.
"""
exp_trials_with_statuses = set()
for status in self.statuses:
exp_trials_with_statuses = exp_trials_with_statuses.union(
experiment.trial_indices_by_status[status]
)
exp_trials_with_statuses = self.experiment_trials_by_status(
experiment, self.statuses
)

# Trials from node should not be none for any new GenerationStrategies
if trials_from_node is None:
logger.warning(
"trials_from_node is None, will check threshold on experiment level.",
"`trials_from_node` is None, will check threshold on"
+ " experiment level for MinimumTrialsInStatus.",
)
return len(exp_trials_with_statuses) >= self.threshold
return (
Expand Down Expand Up @@ -117,12 +134,14 @@ def __init__(
self,
threshold: int,
enforce: bool,
only_in_status: Optional[TrialStatus] = None,
only_in_statuses: Optional[List[TrialStatus]] = None,
transition_to: Optional[str] = None,
not_in_statuses: Optional[List[TrialStatus]] = None,
) -> None:
self.threshold = threshold
self.enforce = enforce
self.only_in_status = only_in_status
self.only_in_statuses = only_in_statuses
self.not_in_statuses = not_in_statuses
super().__init__(transition_to=transition_to)

def is_met(
Expand All @@ -134,37 +153,38 @@ def is_met(
trials_from_node: A set containing the indices of trials that were
generated from this GenerationNode.
"""
# TODO: @mgarrard fix enforce logic
if self.enforce:
trials_to_check = experiment.trials.keys()
# limit the trials to only those in the specified statuses
if self.only_in_statuses is not None:
trials_to_check = self.experiment_trials_by_status(
experiment, self.only_in_statuses
)
# exclude the trials to those not in the specified statuses
if self.not_in_statuses is not None:
trials_to_check -= self.experiment_trials_by_status(
experiment, self.not_in_statuses
)

# Trials from node should not be none for any new GenerationStrategies
if trials_from_node is None:
logger.warning(
"trials_from_node is None, will check threshold on"
+ " experiment level.",
"`trials_from_node` is None, will check threshold on"
+ " experiment level for MaxTrials.",
)
if self.only_in_status is not None:
return (
len(experiment.trial_indices_by_status[self.only_in_status])
>= self.threshold
)
return len(experiment.trials) >= self.threshold
if self.only_in_status is not None:
return (
len(
trials_from_node.intersection(
experiment.trial_indices_by_status[self.only_in_status]
)
)
>= self.threshold
)
return len(trials_from_node) >= self.threshold
return len(trials_to_check) >= self.threshold

return len(trials_from_node.intersection(trials_to_check)) >= self.threshold
return True

def __repr__(self) -> str:
"""Returns a string representation of MaxTrials."""
return (
f"{self.__class__.__name__}(threshold={self.threshold}, "
f"enforce={self.enforce}, only_in_status={self.only_in_status}, "
f"transition_to='{self.transition_to}')"
f"enforce={self.enforce}, only_in_statuses={self.only_in_statuses}, "
f"transition_to='{self.transition_to}', "
f"not_in_statuses={self.not_in_statuses})"
)


Expand Down
11 changes: 8 additions & 3 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,18 @@ def transition_criteria_from_json(
elif criterion_type == "MaxTrials":
criterion_list.append(
MaxTrials(
only_in_status=object_from_json(
criterion_json.pop("only_in_status")
only_in_statuses=object_from_json(
criterion_json.pop("only_in_statuses")
)
if "only_in_status" in criterion_json.keys()
if "only_in_statuses" in criterion_json.keys()
else None,
threshold=criterion_json.pop("threshold"),
enforce=criterion_json.pop("enforce"),
not_in_statuses=object_from_json(
criterion_json.pop("not_in_statuses")
)
if "not_in_statuses" in criterion_json.keys()
else None,
)
)
else:
Expand Down

0 comments on commit b62cbe7

Please sign in to comment.