diff --git a/ax/modelbridge/completion_criterion.py b/ax/modelbridge/completion_criterion.py index 3bade125098..97a875bce4f 100644 --- a/ax/modelbridge/completion_criterion.py +++ b/ax/modelbridge/completion_criterion.py @@ -3,47 +3,51 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from abc import abstractmethod +from logging import Logger -from ax.core.base_trial import TrialStatus +from ax.modelbridge.transition_criterion import ( + MinimumPreferenceOccurances, + MinimumTrialsInStatus, + TransitionCriterion, +) +from ax.utils.common.logger import get_logger -from ax.core.experiment import Experiment -from ax.utils.common.base import Base -from ax.utils.common.serialization import SerializationMixin +logger: Logger = get_logger(__name__) -class CompletionCriterion(Base, SerializationMixin): +class CompletionCriterion(TransitionCriterion): """ - Simple class to descibe a condition which must be met for a GenerationStraytegy - to move to its next GenerationStep. + Deprecated class that has been replaced by `TransitionCriterion`, and will be + fully reaped in a future release. """ - def __init__(self) -> None: - pass + logger.warning( + "CompletionCriterion is deprecated, please use TransitionCriterion instead." + ) + pass - @abstractmethod - def is_met(self, experiment: Experiment) -> bool: - pass - - -class MinimumTrialsInStatus(CompletionCriterion): - def __init__(self, status: TrialStatus, threshold: int) -> None: - self.status = status - self.threshold = threshold - - def is_met(self, experiment: Experiment) -> bool: - return len(experiment.trial_indices_by_status[self.status]) >= self.threshold +class MinimumPreferenceOccurances(MinimumPreferenceOccurances): + """ + Deprecated child class that has been replaced by `MinimumPreferenceOccurances` + in `TransitionCriterion`, and will be fully reaped in a future release. + """ -class MinimumPreferenceOccurances(CompletionCriterion): - def __init__(self, metric_name: str, threshold: int) -> None: - self.metric_name = metric_name - self.threshold = threshold + logger.warning( + "CompletionCriterion, which MinimumPreferenceOccurance inherits from, is" + " deprecated. Please use TransitionCriterion instead." + ) + pass - def is_met(self, experiment: Experiment) -> bool: - data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) - count_no = (data.df["mean"] == 0).sum() - count_yes = (data.df["mean"] != 0).sum() +class MinimumTrialsInStatus(MinimumTrialsInStatus): + """ + Deprecated child class that has been replaced by `MinimumTrialsInStatus` + in `TransitionCriterion`, and will be fully reaped in a future release. + """ - return count_no >= self.threshold and count_yes >= self.threshold + logger.warning( + "CompletionCriterion, which MinimumTrialsInStatus inherits from, is" + " deprecated. Please use TransitionCriterion instead." + ) + pass diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 30ed1eeb8f4..dca4631fca4 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -30,10 +30,14 @@ MaxParallelismReachedException, ) from ax.modelbridge.base import ModelBridge -from ax.modelbridge.completion_criterion import CompletionCriterion from ax.modelbridge.cross_validation import BestModelSelector, CVDiagnostics, CVResult from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec from ax.modelbridge.registry import ModelRegistryBase +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumTrialsInStatus, + TransitionCriterion, +) from ax.utils.common.base import Base, SortableBase from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none @@ -57,15 +61,42 @@ class GenerationNode: - """Base class for generation node, capable of fitting one or more - model specs under the hood and generating candidates from them. + """Base class for GenerationNode, capable of fitting one or more model specs under + the hood and generating candidates from them. + + Args: + model_specs: A list of ModelSpecs to be selected from for generation in this + GenerationNode + should_deduplicate: Whether to deduplicate the parameters of proposed arms + against those of previous arms via rejection sampling. If this is True, + the GenerationStrategy will discard generator runs produced from the + GenerationNode that has `should_deduplicate=True` if they contain arms + already present on the experiment and replace them with new generator runs. + If no generator run with entirely unique arms could be produced in 5 + attempts, a `GenerationStrategyRepeatedPoints` error will be raised, as we + assume that the optimization converged when the model can no longer suggest + unique arms. + node_name: A unique name for the GenerationNode. Used for storage purposes. + transition_criteria: List of TransitionCriterion, each of which describes a + condition that must be met before completing a GenerationNode. All `is_met` + must evaluateTrue for the GenerationStrategy to move on to the next + GenerationNode. + + Note for developers: by "model" here we really mean an Ax ModelBridge object, which + contains an Ax Model under the hood. We call it "model" here to simplify and focus + on explaining the logic of GenerationStep and GenerationStrategy. """ + # Required options: model_specs: List[ModelSpec] + # TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional should_deduplicate: bool _node_name: str + + # Optional specifications _model_spec_to_gen_from: Optional[ModelSpec] = None use_update: bool = False + _transition_criteria: Optional[Sequence[TransitionCriterion]] # [TODO] Handle experiment passing more eloquently by enforcing experiment # attribute is set in generation strategies class @@ -80,6 +111,7 @@ def __init__( best_model_selector: Optional[BestModelSelector] = None, should_deduplicate: bool = False, use_update: bool = False, + transition_criteria: Optional[Sequence[TransitionCriterion]] = None, ) -> None: self._node_name = node_name # While `GenerationNode` only handles a single `ModelSpec` in the `gen` @@ -91,6 +123,7 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate self.use_update = use_update + self._transition_criteria = transition_criteria @property def node_name(self) -> str: @@ -165,6 +198,10 @@ def generation_strategy(self) -> modelbridge.generation_strategy.GenerationStrat ) return not_none(self._generation_strategy) + @property + def transition_criteria(self) -> Sequence[TransitionCriterion]: + return not_none(self._transition_criteria) + @property def experiment(self) -> Experiment: return self.generation_strategy.experiment @@ -450,7 +487,7 @@ class GenerationStep(GenerationNode, SortableBase): model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the step's model's `gen` under the hood; `model_gen_kwargs` will be passed to the model's `gen` like so: `model.gen(**model_gen_kwargs)`. - completion_criteria: List of CompletionCriterion. All `is_met` must evaluate + completion_criteria: List of TransitionCriterion. All `is_met` must evaluate True for the GenerationStrategy to move on to the next Step index: Index of this generation step, for use internally in `Generation Strategy`. Do not assign as it will be reassigned when instantiating @@ -481,7 +518,7 @@ class GenerationStep(GenerationNode, SortableBase): model_gen_kwargs: Optional[Dict[str, Any]] = None # Optional specifications for use in generation strategy: - completion_criteria: Sequence[CompletionCriterion] = field(default_factory=list) + completion_criteria: Sequence[TransitionCriterion] = field(default_factory=list) min_trials_observed: int = 0 max_parallelism: Optional[int] = None use_update: bool = False @@ -534,11 +571,31 @@ def __post_init__(self) -> None: except TypeError: # Factory functions may not always have a model key defined. self.model_name = f"Unknown {model_spec.__class__.__name__}" + + # Create transition criteria for this step. MaximumTrialsInStatus can be used + # to ensure that requirements related to num_trials and enforce_num_trials + # are met. MinimumTrialsInStatus can be used enforce the min_trials_observed + # requirement. + transition_criteria = [] + transition_criteria.append( + MaxTrials( + enforce=self.enforce_num_trials, + threshold=self.num_trials, + ) + ) + transition_criteria.append( + MinimumTrialsInStatus( + status=TrialStatus.COMPLETED, threshold=self.min_trials_observed + ) + ) + transition_criteria += self.completion_criteria + # need to unwrap old completion_criteria super().__init__( node_name=f"GenerationStep_{str(self.index)}", model_specs=[model_spec], should_deduplicate=self.should_deduplicate, use_update=self.use_update, + transition_criteria=transition_criteria, ) @property diff --git a/ax/modelbridge/tests/test_completion_criterion.py b/ax/modelbridge/tests/test_completion_criterion.py index f6726ccb199..b4d00b9c6cc 100644 --- a/ax/modelbridge/tests/test_completion_criterion.py +++ b/ax/modelbridge/tests/test_completion_criterion.py @@ -19,6 +19,12 @@ class TestCompletionCritereon(TestCase): + """ + `CompletionCriterion` is deprecrated and replaced by `TransitionCriterion`. + However, some legacy code still depends on the skelton implementation of + `CompletionCriterion`, so we will keep this test case until full reaping. + """ + def test_single_criterion(self) -> None: criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3) diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py new file mode 100644 index 00000000000..6806485b6cd --- /dev/null +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from unittest.mock import patch + +import pandas as pd +from ax.core.base_trial import TrialStatus +from ax.core.data import Data +from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.modelbridge.registry import Models +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, +) +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_experiment + + +class TestTransitionCriterion(TestCase): + def test_minimum_preference_criterion(self) -> None: + """Tests the minimum preference criterion subcalss of TransitionCriterion.""" + criterion = MinimumPreferenceOccurances(metric_name="m1", threshold=3) + experiment = get_experiment() + generation_strategy = GenerationStrategy( + name="SOBOL+GPEI::default", + steps=[ + GenerationStep( + model=Models.SOBOL, + num_trials=-1, + completion_criteria=[criterion], + ), + GenerationStep( + model=Models.GPEI, + num_trials=-1, + max_parallelism=1, + ), + ], + ) + generation_strategy.experiment = experiment + + # Has not seen enough of each preference + self.assertFalse( + generation_strategy._maybe_move_to_next_step( + raise_data_required_error=False + ) + ) + + data = Data( + df=pd.DataFrame( + { + "trial_index": range(6), + "arm_name": [f"{i}_0" for i in range(6)], + "metric_name": ["m1" for _ in range(6)], + "mean": [0, 0, 0, 1, 1, 1], + "sem": [0 for _ in range(6)], + } + ) + ) + with patch.object(experiment, "fetch_data", return_value=data): + # We have seen three "yes" and three "no" + self.assertTrue( + generation_strategy._maybe_move_to_next_step( + raise_data_required_error=False + ) + ) + self.assertEqual(generation_strategy._curr.model, Models.GPEI) + + def test_default_step_criterion_setup(self) -> None: + """This test ensures that the default completion criterion for GenerationSteps + is set as expected. + + The default completion criterion is to create two TransitionCriterion, one + of type `MaximumTrialsInStatus` and one of type `MinimumTrialsInStatus`. + These are constructed via the inputs of `num_trials`, `enforce_num_trials`, + and `minimum_trials_observed` on the GenerationStep. + """ + experiment = get_experiment() + gs = GenerationStrategy( + name="SOBOL+GPEI::default", + steps=[ + GenerationStep( + model=Models.SOBOL, + num_trials=3, + enforce_num_trials=False, + ), + GenerationStep( + model=Models.GPEI, + num_trials=4, + max_parallelism=1, + min_trials_observed=2, + ), + GenerationStep( + model=Models.GPEI, + num_trials=-1, + max_parallelism=1, + ), + ], + ) + gs.experiment = experiment + + step_0_expected_transition_criteria = [ + MaxTrials(threshold=3, enforce=False), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), + ] + step_1_expected_transition_criteria = [ + MaxTrials(threshold=4, enforce=True), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=2), + ] + step_2_expected_transition_criteria = [ + MaxTrials(threshold=-1, enforce=True), + MinimumTrialsInStatus(status=TrialStatus.COMPLETED, threshold=0), + ] + self.assertEqual( + gs._steps[0].transition_criteria, step_0_expected_transition_criteria + ) + self.assertEqual( + gs._steps[1].transition_criteria, step_1_expected_transition_criteria + ) + self.assertEqual( + gs._steps[2].transition_criteria, step_2_expected_transition_criteria + ) + + # Check default results for `is_met` call + self.assertTrue(gs._steps[0].transition_criteria[0].is_met(experiment)) + self.assertTrue(gs._steps[0].transition_criteria[1].is_met(experiment)) + self.assertFalse(gs._steps[1].transition_criteria[0].is_met(experiment)) + self.assertFalse(gs._steps[1].transition_criteria[1].is_met(experiment)) + + def test_max_trials_status_arg(self) -> None: + """Tests the `only_in_status` argument checks the threshold based on the + number of trials in specified status instead of all trials (which is the + default behavior). + """ + experiment = get_experiment() + criterion = MaxTrials( + threshold=5, only_in_status=TrialStatus.RUNNING, enforce=True + ) + self.assertFalse(criterion.is_met(experiment)) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py new file mode 100644 index 00000000000..86056e48b19 --- /dev/null +++ b/ax/modelbridge/transition_criterion.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from typing import Optional + +from ax.core.base_trial import TrialStatus + +from ax.core.experiment import Experiment +from ax.utils.common.base import Base +from ax.utils.common.serialization import SerializationMixin + + +class TransitionCriterion(Base, SerializationMixin): + """ + Simple class to descibe a condition which must be met for a GenerationStrategy + to move to its next GenerationNode. + """ + + # TODO: @mgarrard add `transition_to` attribute to define the next node + def __init__(self) -> None: + pass + + @abstractmethod + def is_met(self, experiment: Experiment) -> bool: + pass + + +class MinimumTrialsInStatus(TransitionCriterion): + """ + Simple class to decide if the number of trials of a given status in the + GenerationStrategy experiment has reached a certain threshold. + """ + + def __init__(self, status: TrialStatus, threshold: int) -> None: + self.status = status + self.threshold = threshold + + def is_met(self, experiment: Experiment) -> bool: + return len(experiment.trial_indices_by_status[self.status]) >= self.threshold + + +class MaxTrials(TransitionCriterion): + """ + Simple class to enforce a maximum threshold for the number of trials generated + by a specific GenerationNode. + + Args: + threshold: the designated maximum number of trials + enforce: whether or not to enforce the max trial constraint + only_in_status: optional argument for specifying only checking trials with + this status. If not specified, all trial statuses are counted. + """ + + def __init__( + self, + threshold: int, + enforce: bool, + only_in_status: Optional[TrialStatus] = None, + ) -> None: + self.threshold = threshold + self.enforce = enforce + # Optional argument for specifying only checking trials with this status + self.only_in_status = only_in_status + + def is_met(self, experiment: Experiment) -> bool: + if self.enforce: + if self.only_in_status is not None: + return ( + len(experiment.trial_indices_by_status[self.only_in_status]) + >= self.threshold + ) + return experiment.num_trials >= self.threshold + return True + + +class MinimumPreferenceOccurances(TransitionCriterion): + """ + In a preference Experiment (i.e. Metric values may either be zero for No and + nonzero for Yes) do not transition until a minimum number of both Yes and No + responses have been received. + """ + + def __init__(self, metric_name: str, threshold: int) -> None: + self.metric_name = metric_name + self.threshold = threshold + + def is_met(self, experiment: Experiment) -> bool: + # TODO: @mgarrard replace fetch_data with lookup_data + data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) + + count_no = (data.df["mean"] == 0).sum() + count_yes = (data.df["mean"] != 0).sum() + + return count_no >= self.threshold and count_yes >= self.threshold diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index ab5852fbde7..c267dd74685 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -32,6 +32,12 @@ from ax.exceptions.storage import JSONDecodeError from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import _decode_callables_from_references +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, + TransitionCriterion, +) from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -40,6 +46,7 @@ tensor_from_json, trial_from_json, ) + from ax.storage.json_store.registry import ( CORE_CLASS_DECODER_REGISTRY, CORE_DECODER_REGISTRY, @@ -318,6 +325,56 @@ def generator_run_from_json( return generator_run +def transition_criteria_from_json( + transition_criteria_json: List[Dict[str, Any]], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type` to avoid runtime subscripting errors. + decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. + class_decoder_registry: Dict[ + str, Callable[[Dict[str, Any]], Any] + ] = CORE_CLASS_DECODER_REGISTRY, +) -> Optional[List[TransitionCriterion]]: + """Load Ax TransitionCriteria from JSON. + + This function is necessary due to the loading of TrialStatus in + some, but not all, TransitionCriterion. + """ + if transition_criteria_json is None: + return None + + criterion_list = [] + for criterion_json in transition_criteria_json: + criterion_type = criterion_json.pop("__type") + if criterion_type == "MinimumTrialsInStatus": + criterion_list.append( + MinimumTrialsInStatus( + status=object_from_json(criterion_json.pop("status")), + threshold=criterion_json.pop("threshold"), + ) + ) + elif criterion_type == "MaxTrials": + criterion_list.append( + MaxTrials( + only_in_status=object_from_json( + criterion_json.pop("only_in_status") + ) + if "only_in_status" in criterion_json.keys() + else None, + threshold=criterion_json.pop("threshold"), + enforce=criterion_json.pop("enforce"), + ) + ) + else: + criterion_list.append( + MinimumPreferenceOccurances( + metric_name=criterion_json.pop("metric_name"), + threshold=criterion_json.pop("threshold"), + ), + ) + return criterion_list + + def search_space_from_json( search_space_json: Dict[str, Any], # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use @@ -635,7 +692,12 @@ def generation_step_from_json( ) kwargs = generation_step_json.pop("model_kwargs", None) gen_kwargs = generation_step_json.pop("model_gen_kwargs", None) - return GenerationStep( + completion_criteria = ( + transition_criteria_from_json(generation_step_json.pop("completion_criteria")) + if "completion_criteria" in generation_step_json.keys() + else [] + ) + generation_step = GenerationStep( model=object_from_json( generation_step_json.pop("model"), decoder_registry=decoder_registry, @@ -643,10 +705,8 @@ def generation_step_from_json( ), num_trials=generation_step_json.pop("num_trials"), min_trials_observed=generation_step_json.pop("min_trials_observed", 0), - completion_criteria=object_from_json( - generation_step_json.pop("completion_criteria") - ) - if "completion_criteria" in generation_step_json.keys() + completion_criteria=completion_criteria + if completion_criteria is not None else [], max_parallelism=(generation_step_json.pop("max_parallelism", None)), use_update=generation_step_json.pop("use_update", False), @@ -674,6 +734,12 @@ def generation_step_from_json( if "should_deduplicate" in generation_step_json else False, ) + generation_step._transition_criteria = transition_criteria_from_json( + generation_step_json.pop("transition_criteria") + if "transition_criteria" in generation_step_json.keys() + else None + ) + return generation_step def generation_strategy_from_json( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 86e8b79c721..d6a75a3da01 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -49,10 +49,10 @@ from ax.exceptions.core import AxStorageWarning from ax.exceptions.storage import JSONEncodeError from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy -from ax.modelbridge.completion_criterion import CompletionCriterion from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import _encode_callables_as_references from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transition_criterion import TransitionCriterion from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.winsorization_config import WinsorizationConfig @@ -475,6 +475,7 @@ def generation_step_to_dict(generation_step: GenerationStep) -> Dict[str, Any]: ), "index": generation_step.index, "should_deduplicate": generation_step.should_deduplicate, + "transition_criteria": generation_step.transition_criteria, } @@ -499,8 +500,8 @@ def generation_strategy_to_dict( } -def completion_criterion_to_dict(criterion: CompletionCriterion) -> Dict[str, Any]: - """Convert Ax CompletionCriterion to a dictionary.""" +def transition_criterion_to_dict(criterion: TransitionCriterion) -> Dict[str, Any]: + """Convert Ax TransitionCriterion to a dictionary.""" properties = criterion.serialize_init_args(obj=criterion) properties["__type"] = criterion.__class__.__name__ return properties diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 5678e56a513..944dbcfc0bc 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -75,13 +75,14 @@ from ax.metrics.l2norm import L2NormMetric from ax.metrics.noisy_function import NoisyFunctionMetric from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType -from ax.modelbridge.completion_criterion import ( - MinimumPreferenceOccurances, - MinimumTrialsInStatus, -) from ax.modelbridge.factory import Models from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transition_criterion import ( + MaxTrials, + MinimumPreferenceOccurances, + MinimumTrialsInStatus, +) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -104,7 +105,6 @@ botorch_model_to_dict, botorch_modular_to_dict, choice_parameter_to_dict, - completion_criterion_to_dict, data_to_dict, experiment_to_dict, fixed_parameter_to_dict, @@ -141,6 +141,7 @@ surrogate_to_dict, threshold_early_stopping_strategy_to_dict, transform_type_to_dict, + transition_criterion_to_dict, trial_to_dict, winsorization_config_to_dict, ) @@ -189,8 +190,9 @@ MapKeyInfo: map_key_info_to_dict, MapMetric: metric_to_dict, Metric: metric_to_dict, - MinimumTrialsInStatus: completion_criterion_to_dict, - MinimumPreferenceOccurances: completion_criterion_to_dict, + MinimumTrialsInStatus: transition_criterion_to_dict, + MinimumPreferenceOccurances: transition_criterion_to_dict, + MaxTrials: transition_criterion_to_dict, MultiObjective: multi_objective_to_dict, MultiObjectiveBenchmarkProblem: multi_objective_benchmark_problem_to_dict, MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict, @@ -300,6 +302,7 @@ "MapData": MapData, "MapMetric": MapMetric, "MapKeyInfo": MapKeyInfo, + "MaxTrials": MaxTrials, "Metric": Metric, "MinimumTrialsInStatus": MinimumTrialsInStatus, "MinimumPreferenceOccurances": MinimumPreferenceOccurances, diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 4066c9c45a7..d2cca3c76d9 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -14,12 +14,12 @@ from ax.core.parameter import FixedParameter, RangeParameter from ax.core.search_space import SearchSpace from ax.modelbridge.base import ModelBridge -from ax.modelbridge.completion_criterion import MinimumPreferenceOccurances from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Models from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.modelbridge.transition_criterion import MinimumPreferenceOccurances from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.logger import get_logger from ax.utils.testing.core_stubs import ( diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 0c2aa4d2361..920a728ab41 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -31,6 +31,12 @@ Completion Criterion :undoc-members: :show-inheritance: +Transition Criterion +.. automodule:: ax.modelbridge.transition_criterion + :members: + :undoc-members: + :show-inheritance: + Registry ~~~~~~~~~~~~~~~~~~~~ .. automodule:: ax.modelbridge.registry