From 031fa3574155861acdd74b24a184f05cef0d2966 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 4 Oct 2023 15:00:01 -0700 Subject: [PATCH] Add transition criteria to GenerationNode (#1887) Summary: X-link: https://github.com/facebookresearch/aepsych/pull/320 Pull Request resolved: https://github.com/facebook/Ax/pull/1887 In this diff we do a few things: (1) Create a TransitionCriterion class: - this class will subsume the CompletionCriterion class and will be a bit more flexible - it has the same child classes + maximumtrialsinstatus subclass, we may add more subclasses or fields later as we further test (2) Create an list of transitioncriterion from the generationstep class: - minimum_trials_observed can be taken care of by the more flexible MinimumTrialsInStatus class - num_trials and enforce_num_trials can be taken care of by the more flexible MaximumTrialsInStatus class (3) adds a doc string to GenNode class - tangential but easy (4) updates the type of completion_criteria of GenerationStep from CompletionCriterion to TransitionCriterion In following diffs we will: (1) add transition criterion to the repr string + some of the other fields that havent made it yet (2) begin moving the functions related to completing the step up to node and leveraging the transition criterion for checks instead of indexes -- this is where we may need to add additional fields to transitioncriterion (3) add doc strings to everywhere in teh GenNode class (4) add additional unit tests to MaxTrials to bring coverage to 100% (5) skip max trial criterion addition if numtrials == -1 (6) clean up compeletion_criterion class once new ax release can be pinned to aepsych version Reviewed By: lena-kashtelyan Differential Revision: D49509997 fbshipit-source-id: f498eca7a251cbeca2728068bdd9d9b10a50c2c4 --- ax/modelbridge/completion_criterion.py | 66 ++++---- ax/modelbridge/generation_node.py | 67 ++++++++- .../tests/test_completion_criterion.py | 6 + .../tests/test_transition_criterion.py | 141 ++++++++++++++++++ ax/modelbridge/transition_criterion.py | 97 ++++++++++++ ax/storage/json_store/decoder.py | 76 +++++++++- ax/storage/json_store/encoders.py | 7 +- ax/storage/json_store/registry.py | 17 ++- ax/utils/testing/modeling_stubs.py | 2 +- sphinx/source/modelbridge.rst | 6 + 10 files changed, 433 insertions(+), 52 deletions(-) create mode 100644 ax/modelbridge/tests/test_transition_criterion.py create mode 100644 ax/modelbridge/transition_criterion.py diff --git a/ax/modelbridge/completion_criterion.py b/ax/modelbridge/completion_criterion.py index 3bade125098..97a875bce4f 100644 --- a/ax/modelbridge/completion_criterion.py +++ b/ax/modelbridge/completion_criterion.py @@ -3,47 +3,51 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from abc import abstractmethod +from logging import Logger -from ax.core.base_trial import TrialStatus +from ax.modelbridge.transition_criterion import ( + MinimumPreferenceOccurances, + MinimumTrialsInStatus, + TransitionCriterion, +) +from ax.utils.common.logger import get_logger -from ax.core.experiment import Experiment -from ax.utils.common.base import Base -from ax.utils.common.serialization import SerializationMixin +logger: Logger = get_logger(__name__) -class CompletionCriterion(Base, SerializationMixin): +class CompletionCriterion(TransitionCriterion): """ - Simple class to descibe a condition which must be met for a GenerationStraytegy - to move to its next GenerationStep. + Deprecated class that has been replaced by `TransitionCriterion`, and will be + fully reaped in a future release. """ - def __init__(self) -> None: - pass + logger.warning( + "CompletionCriterion is deprecated, please use TransitionCriterion instead." + ) + pass - @abstractmethod - def is_met(self, experiment: Experiment) -> bool: - pass - - -class MinimumTrialsInStatus(CompletionCriterion): - def __init__(self, status: TrialStatus, threshold: int) -> None: - self.status = status - self.threshold = threshold - - def is_met(self, experiment: Experiment) -> bool: - return len(experiment.trial_indices_by_status[self.status]) >= self.threshold +class MinimumPreferenceOccurances(MinimumPreferenceOccurances): + """ + Deprecated child class that has been replaced by `MinimumPreferenceOccurances` + in `TransitionCriterion`, and will be fully reaped in a future release. + """ -class MinimumPreferenceOccurances(CompletionCriterion): - def __init__(self, metric_name: str, threshold: int) -> None: - self.metric_name = metric_name - self.threshold = threshold + logger.warning( + "CompletionCriterion, which MinimumPreferenceOccurance inherits from, is" + " deprecated. Please use TransitionCriterion instead." + ) + pass - def is_met(self, experiment: Experiment) -> bool: - data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) - count_no = (data.df["mean"] == 0).sum() - count_yes = (data.df["mean"] != 0).sum() +class MinimumTrialsInStatus(MinimumTrialsInStatus): + """ + Deprecated child class that has been replaced by `MinimumTrialsInStatus` + in `TransitionCriterion`, and will be fully reaped in a future release. + """ - return count_no >= self.threshold and count_yes >= self.threshold + logger.warning( + "CompletionCriterion, which MinimumTrialsInStatus inherits from, is" + " deprecated. Please use TransitionCriterion instead." + ) + pass diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 30ed1eeb8f4..dca4631fca4 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -30,10 +30,14 @@ MaxParallelismReachedException, ) from ax.modelbridge.base import ModelBridge -from ax.modelbridge.completion_criterion import CompletionCriterion from ax.modelbridge.cross_validation import BestModelSelector, CVDiagnostics, CVResult from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec from ax.modelbridge.registry import ModelRegistryBase +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumTrialsInStatus, + TransitionCriterion, +) from ax.utils.common.base import Base, SortableBase from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none @@ -57,15 +61,42 @@ class GenerationNode: - """Base class for generation node, capable of fitting one or more - model specs under the hood and generating candidates from them. + """Base class for GenerationNode, capable of fitting one or more model specs under + the hood and generating candidates from them. + + Args: + model_specs: A list of ModelSpecs to be selected from for generation in this + GenerationNode + should_deduplicate: Whether to deduplicate the parameters of proposed arms + against those of previous arms via rejection sampling. If this is True, + the GenerationStrategy will discard generator runs produced from the + GenerationNode that has `should_deduplicate=True` if they contain arms + already present on the experiment and replace them with new generator runs. + If no generator run with entirely unique arms could be produced in 5 + attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we + assume that the optimization converged when the model can no longer suggest + unique arms. + node_name: A unique name for the GenerationNode. Used for storage purposes. + transition_criteria: List of TransitionCriterion, each of which describes a + condition that must be met before completing a GenerationNode. All `is_met` + must evaluateTrue for the GenerationStrategy to move on to the next + GenerationNode. + + Note for developers: by "model" here we really mean an Ax ModelBridge object, which + contains an Ax Model under the hood. We call it "model" here to simplify and focus + on explaining the logic of GenerationStep and GenerationStrategy. """ + # Required options: model_specs: List[ModelSpec] + # TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional should_deduplicate: bool _node_name: str + + # Optional specifications _model_spec_to_gen_from: Optional[ModelSpec] = None use_update: bool = False + _transition_criteria: Optional[Sequence[TransitionCriterion]] # [TODO] Handle experiment passing more eloquently by enforcing experiment # attribute is set in generation strategies class @@ -80,6 +111,7 @@ def __init__( best_model_selector: Optional[BestModelSelector] = None, should_deduplicate: bool = False, use_update: bool = False, + transition_criteria: Optional[Sequence[TransitionCriterion]] = None, ) -> None: self._node_name = node_name # While `GenerationNode` only handles a single `ModelSpec` in the `gen` @@ -91,6 +123,7 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate self.use_update = use_update + self._transition_criteria = transition_criteria @property def node_name(self) -> str: @@ -165,6 +198,10 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat ) return not_none(self._generation_strategy) + @property + def transition_criteria(self) -> Sequence[TransitionCriterion]: + return not_none(self._transition_criteria) + @property def experiment(self) -> Experiment: return self.generation_strategy.experiment @@ -450,7 +487,7 @@ class GenerationStep(GenerationNode, SortableBase): model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the step's model's `gen` under the hood; `model_gen_kwargs` will be passed to the model's `gen` like so: `model.gen(**model_gen_kwargs)`. - completion_criteria: List of CompletionCriterion. All `is_met` must evaluate + completion_criteria: List of TransitionCriterion. All `is_met` must evaluate True for the GenerationStrategy to move on to the next Step index: Index of this generation step, for use internally in `Generation Strategy`. Do not assign as it will be reassigned when instantiating @@ -481,7 +518,7 @@ class GenerationStep(GenerationNode, SortableBase): model_gen_kwargs: Optional[Dict[str, Any]] = None # Optional specifications for use in generation strategy: - completion_criteria: Sequence[CompletionCriterion] = field(default_factory=list) + completion_criteria: Sequence[TransitionCriterion] = field(default_factory=list) min_trials_observed: int = 0 max_parallelism: Optional[int] = None use_update: bool = False @@ -534,11 +571,31 @@ def __post_init__(self) -> None: except TypeError: # Factory functions may not always have a model key defined. self.model_name = f"Unknown {model_spec.__class__.__name__}" + + # Create transition criteria for this step. MaximumTrialsInStatus can be used + # to ensure that requirements related to num_trials and enforce_num_trials + # are met. MinimumTrialsInStatus can be used enforce the min_trials_observed + # requirement. + transition_criteria = [] + transition_criteria.append( + MaxTrials( + enforce=self.enforce_num_trials, + threshold=self.num_trials, + ) + ) + transition_criteria.append( + MinimumTrialsInStatus( + status=TrialStatus.COMPLETED, threshold=self.min_trials_observed + ) + ) + transition_criteria += self.completion_criteria + # need to unwrap old completion_criteria super().__init__( node_name=f"GenerationStep_{str(self.index)}", model_specs=[model_spec], should_deduplicate=self.should_deduplicate, use_update=self.use_update, + transition_criteria=transition_criteria, ) @property diff --git a/ax/modelbridge/tests/test_completion_criterion.py b/ax/modelbridge/tests/test_completion_criterion.py index f6726ccb199..b4d00b9c6cc 100644 --- a/ax/modelbridge/tests/test_completion_criterion.py +++ b/ax/modelbridge/tests/test_completion_criterion.py @@ -19,6 +19,12 @@ class TestCompletionCritereon(TestCase): + """ + `CompletionCriterion` is deprecrated and replaced by `TransitionCriterion`. + However, some legacy code still depends on the skelton implementation of + `CompletionCriterion`, so we will keep this test case until full reaping. + """ + def test_single_criterion(self) -> None: criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3) diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py new file mode 100644 index 00000000000..6806485b6cd --- /dev/null +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import pandas as pd +from ax.core.base_trial import TrialStatus +from ax.core.data import Data +from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.modelbridge.registry import Models +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, +) +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_experiment + + +class TestTransitionCriterion(TestCase): + def test_minimum_preference_criterion(self) -> None: + """Tests the minimum preference criterion subcalss of TransitionCriterion.""" + criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3) + experiment = get_experiment() + generation_strategy = GenerationStrategy( + name="SOBOL+GPEI::default", + steps=[ + GenerationStep( + model=Models.SOBOL, + num_trials=-1, + completion_criteria=[criterion], + ), + GenerationStep( + model=Models.GPEI, + num_trials=-1, + max_parallelism=1, + ), + ], + ) + generation_strategy.experiment = experiment + + # Has not seen enough of each preference + self.assertFalse( + generation_strategy._maybe_move_to_next_step( + raise_data_required_error=False + ) + ) + + data = Data( + df=pd.DataFrame( + { + "trial_index": range(6), + "arm_name": [f"{i}_0" for i in range(6)], + "metric_name": ["m1" for _ in range(6)], + "mean": [0, 0, 0, 1, 1, 1], + "sem": [0 for _ in range(6)], + } + ) + ) + with patch.object(experiment, "fetch_data", return_value=data): + # We have seen three "yes" and three "no" + self.assertTrue( + generation_strategy._maybe_move_to_next_step( + raise_data_required_error=False + ) + ) + self.assertEqual(generation_strategy._curr.model, Models.GPEI) + + def test_default_step_criterion_setup(self) -> None: + """This test ensures that the default completion criterion for GenerationSteps + is set as expected. + + The default completion criterion is to create two TransitionCriterion, one + of type `MaximumTrialsInStatus` and one of type `MinimumTrialsInStatus`. + These are constructed via the inputs of `num_trials`, `enforce_num_trials`, + and `minimum_trials_observed` on the GenerationStep. + """ + experiment = get_experiment() + gs = GenerationStrategy( + name="SOBOL+GPEI::default", + steps=[ + GenerationStep( + model=Models.SOBOL, + num_trials=3, + enforce_num_trials=False, + ), + GenerationStep( + model=Models.GPEI, + num_trials=4, + max_parallelism=1, + min_trials_observed=2, + ), + GenerationStep( + model=Models.GPEI, + num_trials=-1, + max_parallelism=1, + ), + ], + ) + gs.experiment = experiment + + step_0_expected_transition_criteria = [ + MaxTrials(threshold=3, enforce=False), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), + ] + step_1_expected_transition_criteria = [ + MaxTrials(threshold=4, enforce=True), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2), + ] + step_2_expected_transition_criteria = [ + MaxTrials(threshold=-1, enforce=True), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), + ] + self.assertEqual( + gs._steps[0].transition_criteria, step_0_expected_transition_criteria + ) + self.assertEqual( + gs._steps[1].transition_criteria, step_1_expected_transition_criteria + ) + self.assertEqual( + gs._steps[2].transition_criteria, step_2_expected_transition_criteria + ) + + # Check default results for `is_met` call + self.assertTrue(gs._steps[0].transition_criteria[0].is_met(experiment)) + self.assertTrue(gs._steps[0].transition_criteria[1].is_met(experiment)) + self.assertFalse(gs._steps[1].transition_criteria[0].is_met(experiment)) + self.assertFalse(gs._steps[1].transition_criteria[1].is_met(experiment)) + + def test_max_trials_status_arg(self) -> None: + """Tests the `only_in_status` argument checks the threshold based on the + number of trials in specified status instead of all trials (which is the + default behavior). + """ + experiment = get_experiment() + criterion = MaxTrials( + threshold=5, only_in_status=TrialStatus.RUNNING, enforce=True + ) + self.assertFalse(criterion.is_met(experiment)) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py new file mode 100644 index 00000000000..86056e48b19 --- /dev/null +++ b/ax/modelbridge/transition_criterion.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from typing import Optional + +from ax.core.base_trial import TrialStatus + +from ax.core.experiment import Experiment +from ax.utils.common.base import Base +from ax.utils.common.serialization import SerializationMixin + + +class TransitionCriterion(Base, SerializationMixin): + """ + Simple class to descibe a condition which must be met for a GenerationStrategy + to move to its next GenerationNode. + """ + + # TODO: @mgarrard add `transition_to` attribute to define the next node + def __init__(self) -> None: + pass + + @abstractmethod + def is_met(self, experiment: Experiment) -> bool: + pass + + +class MinimumTrialsInStatus(TransitionCriterion): + """ + Simple class to decide if the number of trials of a given status in the + GenerationStrategy experiment has reached a certain threshold. + """ + + def __init__(self, status: TrialStatus, threshold: int) -> None: + self.status = status + self.threshold = threshold + + def is_met(self, experiment: Experiment) -> bool: + return len(experiment.trial_indices_by_status[self.status]) >= self.threshold + + +class MaxTrials(TransitionCriterion): + """ + Simple class to enforce a maximum threshold for the number of trials generated + by a specific GenerationNode. + + Args: + threshold: the designated maximum number of trials + enforce: whether or not to enforce the max trial constraint + only_in_status: optional argument for specifying only checking trials with + this status. If not specified, all trial statuses are counted. + """ + + def __init__( + self, + threshold: int, + enforce: bool, + only_in_status: Optional[TrialStatus] = None, + ) -> None: + self.threshold = threshold + self.enforce = enforce + # Optional argument for specifying only checking trials with this status + self.only_in_status = only_in_status + + def is_met(self, experiment: Experiment) -> bool: + if self.enforce: + if self.only_in_status is not None: + return ( + len(experiment.trial_indices_by_status[self.only_in_status]) + >= self.threshold + ) + return experiment.num_trials >= self.threshold + return True + + +class MinimumPreferenceOccurances(TransitionCriterion): + """ + In a preference Experiment (i.e. Metric values may either be zero for No and + nonzero for Yes) do not transition until a minimum number of both Yes and No + responses have been received. + """ + + def __init__(self, metric_name: str, threshold: int) -> None: + self.metric_name = metric_name + self.threshold = threshold + + def is_met(self, experiment: Experiment) -> bool: + # TODO: @mgarrard replace fetch_data with lookup_data + data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) + + count_no = (data.df["mean"] == 0).sum() + count_yes = (data.df["mean"] != 0).sum() + + return count_no >= self.threshold and count_yes >= self.threshold diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index ab5852fbde7..c267dd74685 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -32,6 +32,12 @@ from ax.exceptions.storage import JSONDecodeError from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import _decode_callables_from_references +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, + TransitionCriterion, +) from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -40,6 +46,7 @@ tensor_from_json, trial_from_json, ) + from ax.storage.json_store.registry import ( CORE_CLASS_DECODER_REGISTRY, CORE_DECODER_REGISTRY, @@ -318,6 +325,56 @@ def generator_run_from_json( return generator_run +def transition_criteria_from_json( + transition_criteria_json: List[Dict[str, Any]], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type` to avoid runtime subscripting errors. + decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. + class_decoder_registry: Dict[ + str, Callable[[Dict[str, Any]], Any] + ] = CORE_CLASS_DECODER_REGISTRY, +) -> Optional[List[TransitionCriterion]]: + """Load Ax TransitionCriteria from JSON. + + This function is necessary due to the loading of TrialStatus in + some, but not all, TransitionCriterion. + """ + if transition_criteria_json is None: + return None + + criterion_list = [] + for criterion_json in transition_criteria_json: + criterion_type = criterion_json.pop("__type") + if criterion_type == "MinimumTrialsInStatus": + criterion_list.append( + MinimumTrialsInStatus( + status=object_from_json(criterion_json.pop("status")), + threshold=criterion_json.pop("threshold"), + ) + ) + elif criterion_type == "MaxTrials": + criterion_list.append( + MaxTrials( + only_in_status=object_from_json( + criterion_json.pop("only_in_status") + ) + if "only_in_status" in criterion_json.keys() + else None, + threshold=criterion_json.pop("threshold"), + enforce=criterion_json.pop("enforce"), + ) + ) + else: + criterion_list.append( + MinimumPreferenceOccurances( + metric_name=criterion_json.pop("metric_name"), + threshold=criterion_json.pop("threshold"), + ), + ) + return criterion_list + + def search_space_from_json( search_space_json: Dict[str, Any], # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use @@ -635,7 +692,12 @@ def generation_step_from_json( ) kwargs = generation_step_json.pop("model_kwargs", None) gen_kwargs = generation_step_json.pop("model_gen_kwargs", None) - return GenerationStep( + completion_criteria = ( + transition_criteria_from_json(generation_step_json.pop("completion_criteria")) + if "completion_criteria" in generation_step_json.keys() + else [] + ) + generation_step = GenerationStep( model=object_from_json( generation_step_json.pop("model"), decoder_registry=decoder_registry, @@ -643,10 +705,8 @@ def generation_step_from_json( ), num_trials=generation_step_json.pop("num_trials"), min_trials_observed=generation_step_json.pop("min_trials_observed", 0), - completion_criteria=object_from_json( - generation_step_json.pop("completion_criteria") - ) - if "completion_criteria" in generation_step_json.keys() + completion_criteria=completion_criteria + if completion_criteria is not None else [], max_parallelism=(generation_step_json.pop("max_parallelism", None)), use_update=generation_step_json.pop("use_update", False), @@ -674,6 +734,12 @@ def generation_step_from_json( if "should_deduplicate" in generation_step_json else False, ) + generation_step._transition_criteria = transition_criteria_from_json( + generation_step_json.pop("transition_criteria") + if "transition_criteria" in generation_step_json.keys() + else None + ) + return generation_step def generation_strategy_from_json( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 86e8b79c721..d6a75a3da01 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -49,10 +49,10 @@ from ax.exceptions.core import AxStorageWarning from ax.exceptions.storage import JSONEncodeError from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy -from ax.modelbridge.completion_criterion import CompletionCriterion from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import _encode_callables_as_references from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transition_criterion import TransitionCriterion from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.winsorization_config import WinsorizationConfig @@ -475,6 +475,7 @@ def generation_step_to_dict(generation_step: GenerationStep) -> Dict[str, Any]: ), "index": generation_step.index, "should_deduplicate": generation_step.should_deduplicate, + "transition_criteria": generation_step.transition_criteria, } @@ -499,8 +500,8 @@ def generation_strategy_to_dict( } -def completion_criterion_to_dict(criterion: CompletionCriterion) -> Dict[str, Any]: - """Convert Ax CompletionCriterion to a dictionary.""" +def transition_criterion_to_dict(criterion: TransitionCriterion) -> Dict[str, Any]: + """Convert Ax TransitionCriterion to a dictionary.""" properties = criterion.serialize_init_args(obj=criterion) properties["__type"] = criterion.__class__.__name__ return properties diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 5678e56a513..944dbcfc0bc 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -75,13 +75,14 @@ from ax.metrics.l2norm import L2NormMetric from ax.metrics.noisy_function import NoisyFunctionMetric from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType -from ax.modelbridge.completion_criterion import ( - MinimumPreferenceOccurances, - MinimumTrialsInStatus, -) from ax.modelbridge.factory import Models from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, +) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -104,7 +105,6 @@ botorch_model_to_dict, botorch_modular_to_dict, choice_parameter_to_dict, - completion_criterion_to_dict, data_to_dict, experiment_to_dict, fixed_parameter_to_dict, @@ -141,6 +141,7 @@ surrogate_to_dict, threshold_early_stopping_strategy_to_dict, transform_type_to_dict, + transition_criterion_to_dict, trial_to_dict, winsorization_config_to_dict, ) @@ -189,8 +190,9 @@ MapKeyInfo: map_key_info_to_dict, MapMetric: metric_to_dict, Metric: metric_to_dict, - MinimumTrialsInStatus: completion_criterion_to_dict, - MinimumPreferenceOccurances: completion_criterion_to_dict, + MinimumTrialsInStatus: transition_criterion_to_dict, + MinimumPreferenceOccurances: transition_criterion_to_dict, + MaxTrials: transition_criterion_to_dict, MultiObjective: multi_objective_to_dict, MultiObjectiveBenchmarkProblem: multi_objective_benchmark_problem_to_dict, MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict, @@ -300,6 +302,7 @@ "MapData": MapData, "MapMetric": MapMetric, "MapKeyInfo": MapKeyInfo, + "MaxTrials": MaxTrials, "Metric": Metric, "MinimumTrialsInStatus": MinimumTrialsInStatus, "MinimumPreferenceOccurances": MinimumPreferenceOccurances, diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 4066c9c45a7..d2cca3c76d9 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -14,12 +14,12 @@ from ax.core.parameter import FixedParameter, RangeParameter from ax.core.search_space import SearchSpace from ax.modelbridge.base import ModelBridge -from ax.modelbridge.completion_criterion import MinimumPreferenceOccurances from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Models from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.modelbridge.transition_criterion import MinimumPreferenceOccurances from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.logger import get_logger from ax.utils.testing.core_stubs import ( diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 0c2aa4d2361..920a728ab41 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -31,6 +31,12 @@ Completion Criterion :undoc-members: :show-inheritance: +Transition Criterion +.. automodule:: ax.modelbridge.transition_criterion + :members: + :undoc-members: + :show-inheritance: + Registry ~~~~~~~~~~~~~~~~~~~~ .. automodule:: ax.modelbridge.registry