Skip to content

Commit

Permalink
Add transition_to argument to TransitionCriterion
Browse files Browse the repository at this point in the history
Summary:
It is important that the transition criteria can tell the GenerationStrategy which node to move to once the criteria is met. This diff adds the transition_to field to TransitionCriterion


Things in the pipeline:
(1) Update the transition criterion class to check on a per node basis, instead of per experiment
(2) Use the transition criterion to determine if a node is complete
(3) add is_complete to generationNode and then use that in generation Strategy for moving forward
(4) [Mby] skip max trial criterion addition if numtrials == -1
(5) add transition criterion to the repr string + some of the other fields that havent made it yet
(6) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed

Reviewed By: lena-kashtelyan

Differential Revision: D50295684
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Oct 19, 2023
1 parent 99f9ada commit 402e63c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
6 changes: 4 additions & 2 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ def __post_init__(self) -> None:
# Create transition criteria for this step. MaximumTrialsInStatus can be used
# to ensure that requirements related to num_trials and enforce_num_trials
# are met. MinimumTrialsInStatus can be used enforce the min_trials_observed
# requirement.
# requirement. We set transition_to on GenerationStrategy instead of here as
# GenerationStrategy can see the full list of steps.
transition_criteria = []
transition_criteria.append(
MaxTrials(
Expand All @@ -499,7 +500,8 @@ def __post_init__(self) -> None:
)
transition_criteria.append(
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED, threshold=self.min_trials_observed
status=TrialStatus.COMPLETED,
threshold=self.min_trials_observed,
)
)
transition_criteria += self.completion_criteria
Expand Down
7 changes: 7 additions & 0 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N
# uniqueness is gaurenteed for steps currently due to list structure.
step._node_name = f"GenerationStep_{str(idx)}"
step.index = idx

# Set transition_to field for all but the last step, which remains null.
if idx != len(self._steps):
for transition_criteria in step.transition_criteria:
transition_criteria._transition_to = (
f"GenerationStep_{str(idx + 1)}"
)
step._generation_strategy = self
if not isinstance(step.model, ModelRegistryBase):
self._uses_registered_models = False
Expand Down
24 changes: 18 additions & 6 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,28 @@ def test_default_step_criterion_setup(self) -> None:
gs.experiment = experiment

step_0_expected_transition_criteria = [
MaxTrials(threshold=3, enforce=False),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
MaxTrials(threshold=3, enforce=False, transition_to="GenerationStep_1"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
threshold=0,
transition_to="GenerationStep_1",
),
]
step_1_expected_transition_criteria = [
MaxTrials(threshold=4, enforce=True),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2),
MaxTrials(threshold=4, enforce=True, transition_to="GenerationStep_2"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
threshold=2,
transition_to="GenerationStep_2",
),
]
step_2_expected_transition_criteria = [
MaxTrials(threshold=-1, enforce=True),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
MaxTrials(threshold=-1, enforce=True, transition_to="GenerationStep_3"),
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED,
threshold=0,
transition_to="GenerationStep_3",
),
]
self.assertEqual(
gs._steps[0].transition_criteria, step_0_expected_transition_criteria
Expand Down
37 changes: 32 additions & 5 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,43 @@
# LICENSE file in the root directory of this source tree.

from abc import abstractmethod
from logging import Logger
from typing import Optional

from ax.core.base_trial import TrialStatus

from ax.core.experiment import Experiment
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import SerializationMixin

logger: Logger = get_logger(__name__)


class TransitionCriterion(Base, SerializationMixin):
"""
Simple class to descibe a condition which must be met for a GenerationStrategy
to move to its next GenerationNode.
Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met.
"""

# TODO: @mgarrard add `transition_to` attribute to define the next node
def __init__(self) -> None:
pass
_transition_to: Optional[str] = None

def __init__(self, transition_to: Optional[str] = None) -> None:
self._transition_to = transition_to

@property
def transition_to(self) -> Optional[str]:
"""The name of the next GenerationNode after this TransitionCriterion is
completed. Warns if unset.
"""
if self._transition_to is None:
logger.warning("No transition_to specified on this TransitionCriterion")

return self._transition_to

@abstractmethod
def is_met(self, experiment: Experiment) -> bool:
Expand All @@ -34,9 +53,12 @@ class MinimumTrialsInStatus(TransitionCriterion):
GenerationStrategy experiment has reached a certain threshold.
"""

def __init__(self, status: TrialStatus, threshold: int) -> None:
def __init__(
self, status: TrialStatus, threshold: int, transition_to: Optional[str] = None
) -> None:
self.status = status
self.threshold = threshold
super().__init__(transition_to=transition_to)

def is_met(self, experiment: Experiment) -> bool:
return len(experiment.trial_indices_by_status[self.status]) >= self.threshold
Expand All @@ -59,11 +81,13 @@ def __init__(
threshold: int,
enforce: bool,
only_in_status: Optional[TrialStatus] = None,
transition_to: Optional[str] = None,
) -> None:
self.threshold = threshold
self.enforce = enforce
# Optional argument for specifying only checking trials with this status
self.only_in_status = only_in_status
super().__init__(transition_to=transition_to)

def is_met(self, experiment: Experiment) -> bool:
if self.enforce:
Expand All @@ -83,9 +107,12 @@ class MinimumPreferenceOccurances(TransitionCriterion):
responses have been received.
"""

def __init__(self, metric_name: str, threshold: int) -> None:
def __init__(
self, metric_name: str, threshold: int, transition_to: Optional[str] = None
) -> None:
self.metric_name = metric_name
self.threshold = threshold
super().__init__(transition_to=transition_to)

def is_met(self, experiment: Experiment) -> bool:
# TODO: @mgarrard replace fetch_data with lookup_data
Expand Down

0 comments on commit 402e63c

Please sign in to comment.