Skip to content

Commit

Permalink
Add transition criteria to GenerationNode
Browse files Browse the repository at this point in the history
Summary:
In this diff we do a few things:
(1) Create a TransitionCriterion class:
- this class will subsume the CompletionCriterion class and will be a bit more flexible
- it has the same child classes + maximumtrialsinstatus subclass, we may add more subclasses or fields later as we further test

(2) Create an list of transitioncriterion from the generationstep class:
- minimum_trials_observed can be taken care of by the more flexible MinimumTrialsInStatus class
- num_trials and enforce_num_trials can be taken care of by the more flexible MaximumTrialsInStatus class

(3) adds a doc string to GenNode class - tangential but easy

(4) updates the type of completion_criteria of GenerationStep from CompletionCriterion to TransitionCriterion

(5) clean up compeletion_criterion class bc we don't need it anymore

In following diffs we will:
(1) add transition criterion to the repr string + some of the other fields that havent made it yet
(2) begin moving the functions related to completing the step up to node and leveraging the transition criterion for checks instead of indexes -- this is where we may need to add additional fields to transitioncriterion
(3) add doc strings to everywhere in teh GenNode class
(4) add additional unit tests to MaxTrials to bring coverage to 100%
(5) skip max trial criterion addition if numtrials == -1

Reviewed By: lena-kashtelyan

Differential Revision: D49509997
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Oct 3, 2023
1 parent e67a2e2 commit 9aa4a29
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 132 deletions.
49 changes: 0 additions & 49 deletions ax/modelbridge/completion_criterion.py

This file was deleted.

67 changes: 62 additions & 5 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
MaxParallelismReachedException,
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.completion_criterion import CompletionCriterion
from ax.modelbridge.cross_validation import BestModelSelector, CVDiagnostics, CVResult
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.transition_criterion import (
MaxTrials,
MinimumTrialsInStatus,
TransitionCriterion,
)
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand All @@ -57,15 +61,42 @@


class GenerationNode:
"""Base class for generation node, capable of fitting one or more
model specs under the hood and generating candidates from them.
"""Base class for GenerationNode, capable of fitting one or more model specs under
the hood and generating candidates from them.
Args:
model_specs: A list of ModelSpecs to be selected from for generation in this
GenerationNode
should_deduplicate: Whether to deduplicate the parameters of proposed arms
against those of previous arms via rejection sampling. If this is True,
the GenerationStrategy will discard generator runs produced from the
GenerationNode that has `should_deduplicate=True` if they contain arms
already present on the experiment and replace them with new generator runs.
If no generator run with entirely unique arms could be produced in 5
attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we
assume that the optimization converged when the model can no longer suggest
unique arms.
node_name: A unique name for the GenerationNode. Used for storage purposes.
transition_criteria: List of TransitionCriterion, each of which describes a
condition that must be met before completing a GenerationNode. All `is_met`
must evaluateTrue for the GenerationStrategy to move on to the next
GenerationNode.
Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
on explaining the logic of GenerationStep and GenerationStrategy.
"""

# Required options:
model_specs: List[ModelSpec]
# TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional
should_deduplicate: bool
_node_name: str

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
use_update: bool = False
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
Expand All @@ -80,6 +111,7 @@ def __init__(
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
use_update: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
) -> None:
self._node_name = node_name
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
Expand All @@ -91,6 +123,7 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self.use_update = use_update
self._transition_criteria = transition_criteria

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -165,6 +198,10 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat
)
return not_none(self._generation_strategy)

@property
def transition_criteria(self) -> Sequence[TransitionCriterion]:
return not_none(self._transition_criteria)

@property
def experiment(self) -> Experiment:
return self.generation_strategy.experiment
Expand Down Expand Up @@ -450,7 +487,7 @@ class GenerationStep(GenerationNode, SortableBase):
model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the
step's model's `gen` under the hood; `model_gen_kwargs` will be passed to
the model's `gen` like so: `model.gen(**model_gen_kwargs)`.
completion_criteria: List of CompletionCriterion. All `is_met` must evaluate
completion_criteria: List of TransitionCriterion. All `is_met` must evaluate
True for the GenerationStrategy to move on to the next Step
index: Index of this generation step, for use internally in `Generation
Strategy`. Do not assign as it will be reassigned when instantiating
Expand Down Expand Up @@ -481,7 +518,7 @@ class GenerationStep(GenerationNode, SortableBase):
model_gen_kwargs: Optional[Dict[str, Any]] = None

# Optional specifications for use in generation strategy:
completion_criteria: Sequence[CompletionCriterion] = field(default_factory=list)
completion_criteria: Sequence[TransitionCriterion] = field(default_factory=list)
min_trials_observed: int = 0
max_parallelism: Optional[int] = None
use_update: bool = False
Expand Down Expand Up @@ -534,11 +571,31 @@ def __post_init__(self) -> None:
except TypeError:
# Factory functions may not always have a model key defined.
self.model_name = f"Unknown {model_spec.__class__.__name__}"

# Create transition criteria for this step. MaximumTrialsInStatus can be used
# to ensure that requirements related to num_trials and enforce_num_trials
# are met. MinimumTrialsInStatus can be used enforce the min_trials_observed
# requirement.
transition_criteria = []
transition_criteria.append(
MaxTrials(
enforce=self.enforce_num_trials,
threshold=self.num_trials,
)
)
transition_criteria.append(
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED, threshold=self.min_trials_observed
)
)
transition_criteria += self.completion_criteria
# need to unwrap old completion_criteria
super().__init__(
node_name=f"GenerationStep_{str(self.index)}",
model_specs=[model_spec],
should_deduplicate=self.should_deduplicate,
use_update=self.use_update,
transition_criteria=transition_criteria,
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.modelbridge.completion_criterion import (
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import (
MaxTrials,
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment


class TestCompletionCritereon(TestCase):
def test_single_criterion(self) -> None:
class TestTransitionCriterion(TestCase):
def test_minimum_preference_criterion(self) -> None:
criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3)

experiment = get_experiment()

generation_strategy = GenerationStrategy(
name="SOBOL+GPEI::default",
steps=[
GenerationStep(
model=Models.SOBOL, num_trials=-1, completion_criteria=[criterion]
model=Models.SOBOL,
num_trials=-1,
completion_criteria=[criterion],
),
GenerationStep(
model=Models.GPEI,
Expand Down Expand Up @@ -64,22 +65,31 @@ def test_single_criterion(self) -> None:
raise_data_required_error=False
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)

def test_many_criteria(self) -> None:
criteria = [
MinimumPreferenceOccurances(metric_name="m1", threshold=3),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=5),
]
def test_default_step_criterion_setup(self) -> None:
"""This test ensures that the default completion criterion for GenerationSteps
is set as expected.
The default completion criterion is to create two TransitionCriterion, one
of type `MaximumTrialsInStatus` and one of type `MinimumTrialsInStatus`.
These are constructed via the inputs of `num_trials`, `enforce_num_trials`,
and `minimum_trials_observed` on the GenerationStep.
"""
experiment = get_experiment()

generation_strategy = GenerationStrategy(
gs = GenerationStrategy(
name="SOBOL+GPEI::default",
steps=[
GenerationStep(
model=Models.SOBOL, num_trials=-1, completion_criteria=criteria
model=Models.SOBOL,
num_trials=3,
enforce_num_trials=False,
),
GenerationStep(
model=Models.GPEI,
num_trials=4,
max_parallelism=1,
min_trials_observed=2,
),
GenerationStep(
model=Models.GPEI,
Expand All @@ -88,51 +98,32 @@ def test_many_criteria(self) -> None:
),
],
)
generation_strategy.experiment = experiment
gs.experiment = experiment

# Has not seen enough of each preference
self.assertFalse(
generation_strategy._maybe_move_to_next_step(
raise_data_required_error=False
)
step_0_expected_transition_criteria = [
MaxTrials(threshold=3, enforce=False),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
]
step_1_expected_transition_criteria = [
MaxTrials(threshold=4, enforce=True),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2),
]
step_2_expected_transition_criteria = [
MaxTrials(threshold=-1, enforce=True),
MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
]
self.assertEqual(
gs._steps[0].transition_criteria, step_0_expected_transition_criteria
)

data = Data(
df=pd.DataFrame(
{
"trial_index": range(6),
"arm_name": [f"{i}_0" for i in range(6)],
"metric_name": ["m1" for _ in range(6)],
"mean": [0, 0, 0, 1, 1, 1],
"sem": [0 for _ in range(6)],
}
)
self.assertEqual(
gs._steps[1].transition_criteria, step_1_expected_transition_criteria
)
with patch.object(experiment, "fetch_data", return_value=data):
# We have seen three "yes" and three "no", but not enough trials
# are completed
self.assertFalse(
generation_strategy._maybe_move_to_next_step(
raise_data_required_error=False
)
)

experiment._trial_indices_by_status = {TrialStatus.COMPLETED: {*range(6)}}
# Enough trials are completed but we have not seen three "yes" and three
# "no"
self.assertFalse(
generation_strategy._maybe_move_to_next_step(
raise_data_required_error=False
)
self.assertEqual(
gs._steps[2].transition_criteria, step_2_expected_transition_criteria
)

with patch.object(experiment, "fetch_data", return_value=data):
# Enough trials are completed but we have not seen three "yes" and three
# "no"
self.assertTrue(
generation_strategy._maybe_move_to_next_step(
raise_data_required_error=False
)
)

self.assertEqual(generation_strategy._curr.model, Models.GPEI)
# Check default results for `is_met` call
self.assertTrue(gs._steps[0].transition_criteria[0].is_met(experiment))
self.assertTrue(gs._steps[0].transition_criteria[1].is_met(experiment))
self.assertFalse(gs._steps[1].transition_criteria[0].is_met(experiment))
self.assertFalse(gs._steps[1].transition_criteria[1].is_met(experiment))
Loading

0 comments on commit 9aa4a29

Please sign in to comment.