Skip to content

Commit

Permalink
Add transition criteria to GenerationNode (#1887)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/aepsych#320


In this diff we do a few things:
(1) Create a TransitionCriterion class:
- this class will subsume the CompletionCriterion class and will be a bit more flexible
- it has the same child classes plus a new MaximumTrialsInStatus subclass; we may add more subclasses or fields later as we test further

(2) Create a list of TransitionCriterion from the GenerationStep class:
- minimum_trials_observed can be taken care of by the more flexible MinimumTrialsInStatus class
- num_trials and enforce_num_trials can be taken care of by the more flexible MaximumTrialsInStatus class

(3) adds a doc string to GenNode class - tangential but easy

(4) updates the type of completion_criteria of GenerationStep from CompletionCriterion to TransitionCriterion


In following diffs we will:
(1) add transition criterion to the repr string + some of the other fields that haven't made it yet
(2) begin moving the functions related to completing the step up to node and leveraging the transition criterion for checks instead of indexes -- this is where we may need to add additional fields to transitioncriterion
(3) add docstrings everywhere in the GenNode class
(4) add additional unit tests to MaxTrials to bring coverage to 100%
(5) skip max trial criterion addition if numtrials == -1
(6) clean up completion_criterion class once new ax release can be pinned to aepsych version

Reviewed By: lena-kashtelyan

Differential Revision: D49509997
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Oct 4, 2023
1 parent 994fc89 commit 6fea95d
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 52 deletions.
66 changes: 35 additions & 31 deletions ax/modelbridge/completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,51 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import abstractmethod
from logging import Logger

from ax.core.base_trial import TrialStatus
from ax.modelbridge.transition_criterion import (
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
TransitionCriterion,
)
from ax.utils.common.logger import get_logger

from ax.core.experiment import Experiment
from ax.utils.common.base import Base
from ax.utils.common.serialization import SerializationMixin
logger: Logger = get_logger(__name__)


class CompletionCriterion(Base, SerializationMixin):
class CompletionCriterion(TransitionCriterion):
"""
Simple class to descibe a condition which must be met for a GenerationStraytegy
to move to its next GenerationStep.
Deprecated class that has been replaced by `TransitionCriterion`, and will be
fully reaped in a future release.
"""

def __init__(self) -> None:
pass
logger.warning(
"CompletionCriterion is deprecated, please use TransitionCriterion instead."
)
pass

@abstractmethod
def is_met(self, experiment: Experiment) -> bool:
pass


class MinimumTrialsInStatus(CompletionCriterion):
def __init__(self, status: TrialStatus, threshold: int) -> None:
self.status = status
self.threshold = threshold

def is_met(self, experiment: Experiment) -> bool:
return len(experiment.trial_indices_by_status[self.status]) >= self.threshold

class MinimumPreferenceOccurances(MinimumPreferenceOccurances):
"""
Deprecated child class that has been replaced by `MinimumPreferenceOccurances`
in `TransitionCriterion`, and will be fully reaped in a future release.
"""

class MinimumPreferenceOccurances(CompletionCriterion):
def __init__(self, metric_name: str, threshold: int) -> None:
self.metric_name = metric_name
self.threshold = threshold
logger.warning(
"CompletionCriterion, which MinimumPreferenceOccurance inherits from, is"
" deprecated. Please use TransitionCriterion instead."
)
pass

def is_met(self, experiment: Experiment) -> bool:
data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]])

count_no = (data.df["mean"] == 0).sum()
count_yes = (data.df["mean"] != 0).sum()
class MinimumTrialsInStatus(MinimumTrialsInStatus):
"""
Deprecated child class that has been replaced by `MinimumTrialsInStatus`
in `TransitionCriterion`, and will be fully reaped in a future release.
"""

return count_no >= self.threshold and count_yes >= self.threshold
logger.warning(
"CompletionCriterion, which MinimumTrialsInStatus inherits from, is"
" deprecated. Please use TransitionCriterion instead."
)
pass
67 changes: 62 additions & 5 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
MaxParallelismReachedException,
)
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.completion_criterion import CompletionCriterion
from ax.modelbridge.cross_validation import BestModelSelector, CVDiagnostics, CVResult
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.transition_criterion import (
MaxTrials,
MinimumTrialsInStatus,
TransitionCriterion,
)
from ax.utils.common.base import Base, SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand All @@ -57,15 +61,42 @@


class GenerationNode:
"""Base class for generation node, capable of fitting one or more
model specs under the hood and generating candidates from them.
"""Base class for GenerationNode, capable of fitting one or more model specs under
the hood and generating candidates from them.
Args:
model_specs: A list of ModelSpecs to be selected from for generation in this
GenerationNode
should_deduplicate: Whether to deduplicate the parameters of proposed arms
against those of previous arms via rejection sampling. If this is True,
the GenerationStrategy will discard generator runs produced from the
GenerationNode that has `should_deduplicate=True` if they contain arms
already present on the experiment and replace them with new generator runs.
If no generator run with entirely unique arms could be produced in 5
attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we
assume that the optimization converged when the model can no longer suggest
unique arms.
node_name: A unique name for the GenerationNode. Used for storage purposes.
transition_criteria: List of TransitionCriterion, each of which describes a
condition that must be met before completing a GenerationNode. All `is_met`
must evaluate True for the GenerationStrategy to move on to the next
GenerationNode.
Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
on explaining the logic of GenerationStep and GenerationStrategy.
"""

# Required options:
model_specs: List[ModelSpec]
# TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional
should_deduplicate: bool
_node_name: str

# Optional specifications
_model_spec_to_gen_from: Optional[ModelSpec] = None
use_update: bool = False
_transition_criteria: Optional[Sequence[TransitionCriterion]]

# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
Expand All @@ -80,6 +111,7 @@ def __init__(
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
use_update: bool = False,
transition_criteria: Optional[Sequence[TransitionCriterion]] = None,
) -> None:
self._node_name = node_name
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
Expand All @@ -91,6 +123,7 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate
self.use_update = use_update
self._transition_criteria = transition_criteria

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -165,6 +198,10 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat
)
return not_none(self._generation_strategy)

@property
def transition_criteria(self) -> Sequence[TransitionCriterion]:
return not_none(self._transition_criteria)

@property
def experiment(self) -> Experiment:
return self.generation_strategy.experiment
Expand Down Expand Up @@ -450,7 +487,7 @@ class GenerationStep(GenerationNode, SortableBase):
model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the
step's model's `gen` under the hood; `model_gen_kwargs` will be passed to
the model's `gen` like so: `model.gen(**model_gen_kwargs)`.
completion_criteria: List of CompletionCriterion. All `is_met` must evaluate
completion_criteria: List of TransitionCriterion. All `is_met` must evaluate
True for the GenerationStrategy to move on to the next Step
index: Index of this generation step, for use internally in `Generation
Strategy`. Do not assign as it will be reassigned when instantiating
Expand Down Expand Up @@ -481,7 +518,7 @@ class GenerationStep(GenerationNode, SortableBase):
model_gen_kwargs: Optional[Dict[str, Any]] = None

# Optional specifications for use in generation strategy:
completion_criteria: Sequence[CompletionCriterion] = field(default_factory=list)
completion_criteria: Sequence[TransitionCriterion] = field(default_factory=list)
min_trials_observed: int = 0
max_parallelism: Optional[int] = None
use_update: bool = False
Expand Down Expand Up @@ -534,11 +571,31 @@ def __post_init__(self) -> None:
except TypeError:
# Factory functions may not always have a model key defined.
self.model_name = f"Unknown {model_spec.__class__.__name__}"

# Create transition criteria for this step. MaximumTrialsInStatus can be used
# to ensure that requirements related to num_trials and enforce_num_trials
# are met. MinimumTrialsInStatus can be used enforce the min_trials_observed
# requirement.
transition_criteria = []
transition_criteria.append(
MaxTrials(
enforce=self.enforce_num_trials,
threshold=self.num_trials,
)
)
transition_criteria.append(
MinimumTrialsInStatus(
status=TrialStatus.COMPLETED, threshold=self.min_trials_observed
)
)
transition_criteria += self.completion_criteria
# need to unwrap old completion_criteria
super().__init__(
node_name=f"GenerationStep_{str(self.index)}",
model_specs=[model_spec],
should_deduplicate=self.should_deduplicate,
use_update=self.use_update,
transition_criteria=transition_criteria,
)

@property
Expand Down
6 changes: 6 additions & 0 deletions ax/modelbridge/tests/test_completion_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@


class TestCompletionCritereon(TestCase):
"""
`CompletionCriterion` is deprecated and replaced by `TransitionCriterion`.
However, some legacy code still depends on the skeleton implementation of
`CompletionCriterion`, so we will keep this test case until full reaping.
"""

def test_single_criterion(self) -> None:
criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3)

Expand Down
141 changes: 141 additions & 0 deletions ax/modelbridge/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import (
MaxTrials,
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
)
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment


class TestTransitionCriterion(TestCase):
    """Tests for `TransitionCriterion` subclasses and for the transition
    criteria that `GenerationStep` builds from its own arguments.
    """

    def test_minimum_preference_criterion(self) -> None:
        """Tests the minimum preference criterion subclass of TransitionCriterion."""
        criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3)
        experiment = get_experiment()
        generation_strategy = GenerationStrategy(
            name="SOBOL+GPEI::default",
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=-1,
                    completion_criteria=[criterion],
                ),
                GenerationStep(
                    model=Models.GPEI,
                    num_trials=-1,
                    max_parallelism=1,
                ),
            ],
        )
        generation_strategy.experiment = experiment

        # Has not seen enough of each preference
        self.assertFalse(
            generation_strategy._maybe_move_to_next_step(
                raise_data_required_error=False
            )
        )

        # Six observations of metric "m1": three with mean 0 and three with
        # mean 1, which matches the criterion's threshold of 3.
        data = Data(
            df=pd.DataFrame(
                {
                    "trial_index": range(6),
                    "arm_name": [f"{i}_0" for i in range(6)],
                    "metric_name": ["m1" for _ in range(6)],
                    "mean": [0, 0, 0, 1, 1, 1],
                    "sem": [0 for _ in range(6)],
                }
            )
        )
        # Patch `fetch_data` on the experiment so it returns the data above.
        with patch.object(experiment, "fetch_data", return_value=data):
            # We have seen three "yes" and three "no"
            self.assertTrue(
                generation_strategy._maybe_move_to_next_step(
                    raise_data_required_error=False
                )
            )
            self.assertEqual(generation_strategy._curr.model, Models.GPEI)

    def test_default_step_criterion_setup(self) -> None:
        """This test ensures that the default transition criteria for
        GenerationSteps are set as expected.

        By default each step gets two TransitionCriterion, one of type
        `MaxTrials` and one of type `MinimumTrialsInStatus`. These are
        constructed from the `num_trials`, `enforce_num_trials`, and
        `min_trials_observed` inputs of the GenerationStep.
        """
        experiment = get_experiment()
        gs = GenerationStrategy(
            name="SOBOL+GPEI::default",
            steps=[
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=3,
                    enforce_num_trials=False,
                ),
                GenerationStep(
                    model=Models.GPEI,
                    num_trials=4,
                    max_parallelism=1,
                    min_trials_observed=2,
                ),
                GenerationStep(
                    model=Models.GPEI,
                    num_trials=-1,
                    max_parallelism=1,
                ),
            ],
        )
        gs.experiment = experiment

        # Expected criteria mirror each step's arguments above; arguments a
        # step leaves unset appear here as `enforce=True` and `threshold=0`.
        step_0_expected_transition_criteria = [
            MaxTrials(threshold=3, enforce=False),
            MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
        ]
        step_1_expected_transition_criteria = [
            MaxTrials(threshold=4, enforce=True),
            MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2),
        ]
        step_2_expected_transition_criteria = [
            MaxTrials(threshold=-1, enforce=True),
            MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0),
        ]
        self.assertEqual(
            gs._steps[0].transition_criteria, step_0_expected_transition_criteria
        )
        self.assertEqual(
            gs._steps[1].transition_criteria, step_1_expected_transition_criteria
        )
        self.assertEqual(
            gs._steps[2].transition_criteria, step_2_expected_transition_criteria
        )

        # Check default results for `is_met` call
        self.assertTrue(gs._steps[0].transition_criteria[0].is_met(experiment))
        self.assertTrue(gs._steps[0].transition_criteria[1].is_met(experiment))
        self.assertFalse(gs._steps[1].transition_criteria[0].is_met(experiment))
        self.assertFalse(gs._steps[1].transition_criteria[1].is_met(experiment))

    def test_max_trials_status_arg(self) -> None:
        """Tests that with the `only_in_status` argument, the threshold is
        checked against the number of trials in the specified status instead
        of all trials (which is the default behavior).
        """
        experiment = get_experiment()
        criterion = MaxTrials(
            threshold=5, only_in_status=TrialStatus.RUNNING, enforce=True
        )
        self.assertFalse(criterion.is_met(experiment))
Loading

0 comments on commit 6fea95d

Please sign in to comment.