diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 74518e3ebe7..d4fbeda78a2 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -14,7 +14,7 @@ from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Models -from ax.models.torch.botorch_modular.model import SurrogateSpec +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.service.scheduler import SchedulerOptions from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement diff --git a/ax/benchmark/tests/methods/test_methods.py b/ax/benchmark/tests/methods/test_methods.py index 9c63e2705e0..0dc1aa60cee 100644 --- a/ax/benchmark/tests/methods/test_methods.py +++ b/ax/benchmark/tests/methods/test_methods.py @@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None: self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient) surrogate_spec = model_kwargs["surrogate_spec"] self.assertEqual( - surrogate_spec.botorch_model_class.__name__, + surrogate_spec.model_configs[0].botorch_model_class.__name__, "SingleTaskGP", ) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 9d903ad8c9e..27f0ea653a8 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -59,10 +59,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.random.uniform import UniformGenerator from ax.models.torch.botorch import BotorchModel -from ax.models.torch.botorch_modular.model import ( - BoTorchModel as ModularBoTorchModel, - SurrogateSpec, -) +from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.models.torch.cbo_sac import SACBO from ax.utils.common.kwargs import ( diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 7f77bbde548..20f17d7ce70 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -28,8 +28,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase @@ -100,7 +100,7 @@ def test_SAASBO(self) -> None: SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.model_configs[0].botorch_model_class, + saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class, SaasFullyBayesianSingleTaskGP, ) diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 03804f6f7de..ebd04a2f2e3 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -10,7 +10,6 @@ import warnings from collections import OrderedDict from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field from typing import Any import numpy.typing as npt @@ -23,12 +22,11 @@ get_rounding_func, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( check_outcome_dataset_match, choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, - ModelConfig, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -38,53 +36,10 @@ from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.deterministic import FixedSingleSampleModel -from botorch.models.model import Model -from botorch.models.transforms.input import InputTransform -from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset -from gpytorch.kernels.kernel import Kernel -from gpytorch.likelihoods import Likelihood -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor -@dataclass(frozen=True) -class SurrogateSpec: - """ - Fields in the SurrogateSpec dataclass correspond to arguments in - ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which - outcomes the Surrogate is responsible for modeling. - When ``BotorchModel.fit`` is called, these fields will be used to construct the - requisite Surrogate objects. - If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. - """ - - botorch_model_class: type[Model] | None = None - botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) - - mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood - mll_kwargs: dict[str, Any] = field(default_factory=dict) - - covar_module_class: type[Kernel] | None = None - covar_module_kwargs: dict[str, Any] | None = None - - likelihood_class: type[Likelihood] | None = None - likelihood_kwargs: dict[str, Any] | None = None - - input_transform_classes: list[type[InputTransform]] | None = None - input_transform_options: dict[str, dict[str, Any]] | None = None - - outcome_transform_classes: list[type[OutcomeTransform]] | None = None - outcome_transform_options: dict[str, dict[str, Any]] | None = None - - allow_batched_models: bool = True - - model_configs: list[ModelConfig] = field(default_factory=list) - metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) - outcomes: list[str] = field(default_factory=list) - - class BoTorchModel(TorchModel, Base): """**All classes in 'botorch_modular' directory are under construction, incomplete, and should be treated as alpha @@ -230,31 +185,17 @@ def fit( # If a surrogate has not been constructed, construct it. if self._surrogate is None: - if (spec := self.surrogate_spec) is not None: - self._surrogate = Surrogate( - botorch_model_class=spec.botorch_model_class, - model_options=spec.botorch_model_kwargs, - mll_class=spec.mll_class, - mll_options=spec.mll_kwargs, - covar_module_class=spec.covar_module_class, - covar_module_options=spec.covar_module_kwargs, - likelihood_class=spec.likelihood_class, - likelihood_options=spec.likelihood_kwargs, - input_transform_classes=spec.input_transform_classes, - input_transform_options=spec.input_transform_options, - outcome_transform_classes=spec.outcome_transform_classes, - outcome_transform_options=spec.outcome_transform_options, - model_configs=spec.model_configs, - metric_to_model_configs=spec.metric_to_model_configs, - allow_batched_models=spec.allow_batched_models, - ) + if self.surrogate_spec is not None: + self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec) else: self._surrogate = Surrogate() # Fit the surrogate. - for config in self.surrogate.model_configs: + for config in self.surrogate.surrogate_spec.model_configs: config.model_options.update(additional_model_inputs) - for config_list in self.surrogate.metric_to_model_configs.values(): + for ( + config_list + ) in self.surrogate.surrogate_spec.metric_to_model_configs.values(): for config in config_list: config.model_options.update(additional_model_inputs) self.surrogate.fit( diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 80ecbd2a484..802ed179b70 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -13,6 +13,7 @@ from collections import OrderedDict from collections.abc import Sequence from copy import deepcopy +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -271,17 +272,169 @@ def _set_formatted_inputs( formatted_model_inputs[input_name] = input_class(**input_options) -def _raise_deprecation_warning(*args: Any, **kwargs: Any) -> None: +def _raise_deprecation_warning( + is_surrogate: bool = False, + **kwargs: Any, +) -> bool: + """Raise deprecation warnings for deprecated arguments. + + Args: + is_surrogate: A boolean indicating whether the warning is called from + Surrogate. + + Returns: + A boolean indicating whether any deprecation warnings were raised. + """ + msg = "{k} is deprecated and will be removed in a future version. " + if is_surrogate: + msg += "Please specify {k} via `surrogate_spec.model_configs`." + else: + msg += "Please specify {k} via `model_configs`." + warnings_raised = False + default_is_dict = {"botorch_model_kwargs", "mll_kwargs"} for k, v in kwargs.items(): - if (v is not None and k != "mll_class") or ( + should_raise = False + if k in default_is_dict: + if v != {}: + should_raise = True + elif (v is not None and k != "mll_class") or ( k == "mll_class" and v is not ExactMarginalLogLikelihood ): + should_raise = True + if should_raise: warnings.warn( - f"{k} is deprecated and will be removed in a future version. " - f"Please specify {k} via `model_configs`.", + msg.format(k=k), DeprecationWarning, stacklevel=3, ) + warnings_raised = True + return warnings_raised + + +def get_model_config_from_deprecated_args( + botorch_model_class: type[Model] | None, + model_options: dict[str, Any] | None, + mll_class: type[MarginalLogLikelihood] | None, + mll_options: dict[str, Any] | None, + outcome_transform_classes: list[type[OutcomeTransform]] | None, + outcome_transform_options: dict[str, dict[str, Any]] | None, + input_transform_classes: list[type[InputTransform]] | None, + input_transform_options: dict[str, dict[str, Any]] | None, + covar_module_class: type[Kernel] | None, + covar_module_options: dict[str, Any] | None, + likelihood_class: type[Likelihood] | None, + likelihood_options: dict[str, Any] | None, +) -> ModelConfig: + """Construct a ModelConfig from deprecated arguments.""" + model_config_kwargs = { + "botorch_model_class": botorch_model_class, + "model_options": (model_options or {}).copy(), + "mll_class": mll_class, + "mll_options": (mll_options or {}).copy(), + "outcome_transform_classes": outcome_transform_classes, + "outcome_transform_options": outcome_transform_options, + "input_transform_classes": input_transform_classes, + "input_transform_options": input_transform_options, + "covar_module_class": covar_module_class, + "covar_module_options": covar_module_options, + "likelihood_class": likelihood_class, + "likelihood_options": likelihood_options, + } + model_config_kwargs = { + k: v for k, v in model_config_kwargs.items() if v is not None + } + # pyre-fixme [6]: Incompatible parameter type [6]: In call + # `ModelConfig.__init__`, for 1st positional argument, expected + # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], + # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], + # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, + # Model]], Type[Likelihood], Type[Kernel]]`. + return ModelConfig(**model_config_kwargs) + + +@dataclass(frozen=True) +class SurrogateSpec: + """ + Fields in the SurrogateSpec dataclass correspond to arguments in + ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which + outcomes the Surrogate is responsible for modeling. + When ``BotorchModel.fit`` is called, these fields will be used to construct the + requisite Surrogate objects. + If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. + """ + + botorch_model_class: type[Model] | None = None + botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) + + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_kwargs: dict[str, Any] = field(default_factory=dict) + + covar_module_class: type[Kernel] | None = None + covar_module_kwargs: dict[str, Any] | None = None + + likelihood_class: type[Likelihood] | None = None + likelihood_kwargs: dict[str, Any] | None = None + + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = None + + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] | None = None + + allow_batched_models: bool = True + + model_configs: list[ModelConfig] = field(default_factory=list) + metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) + outcomes: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + warnings_raised = _raise_deprecation_warning( + is_surrogate=False, + botorch_model_class=self.botorch_model_class, + botorch_model_kwargs=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_kwargs=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + if len(self.model_configs) == 0: + model_config = get_model_config_from_deprecated_args( + botorch_model_class=self.botorch_model_class, + model_options=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_options=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + # re-initialize with the non-deprecated arguments + self.__init__( + model_configs=[model_config], + metric_to_model_configs=self.metric_to_model_configs, + allow_batched_models=self.allow_batched_models, + outcomes=self.outcomes, + ) + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) + if len(self.model_configs) > 1 or any( + len(model_config) > 1 + for model_config in self.metric_to_model_configs.values() + ): + raise NotImplementedError("Only one model config per metric is supported.") class Surrogate(Base): @@ -359,23 +512,23 @@ class string names and the values are dictionaries of input transform def __init__( self, + surrogate_spec: SurrogateSpec | None = None, botorch_model_class: type[Model] | None = None, model_options: dict[str, Any] | None = None, mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood, mll_options: dict[str, Any] | None = None, - outcome_transform_classes: Sequence[type[OutcomeTransform]] | None = None, + outcome_transform_classes: list[type[OutcomeTransform]] | None = None, outcome_transform_options: dict[str, dict[str, Any]] | None = None, - input_transform_classes: Sequence[type[InputTransform]] | None = None, + input_transform_classes: list[type[InputTransform]] | None = None, input_transform_options: dict[str, dict[str, Any]] | None = None, covar_module_class: type[Kernel] | None = None, covar_module_options: dict[str, Any] | None = None, likelihood_class: type[Likelihood] | None = None, likelihood_options: dict[str, Any] | None = None, - model_configs: list[ModelConfig] | None = None, - metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> None: - _raise_deprecation_warning( + warnings_raised = _raise_deprecation_warning( + is_surrogate=True, botorch_model_class=botorch_model_class, model_options=model_options, mll_class=mll_class, @@ -389,43 +542,34 @@ def __init__( likelihood_class=likelihood_class, likelihood_options=likelihood_options, ) - model_configs = model_configs or [] - if len(model_configs) == 0: - model_config_kwargs = { - "botorch_model_class": botorch_model_class, - "model_options": (model_options or {}).copy(), - "mll_class": mll_class, - "mll_options": mll_options, - "outcome_transform_classes": outcome_transform_classes, - "outcome_transform_options": outcome_transform_options, - "input_transform_classes": input_transform_classes, - "input_transform_options": input_transform_options, - "covar_module_class": covar_module_class, - "covar_module_options": covar_module_options, - "likelihood_class": likelihood_class, - "likelihood_options": likelihood_options, - } - model_config_kwargs = { - k: v for k, v in model_config_kwargs.items() if v is not None - } - # pyre-fixme [6]: Incompatible parameter type [6]: In call - # `ModelConfig.__init__`, for 1st positional argument, expected - # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], - # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], - # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, - # Model]], Type[Likelihood], Type[Kernel]]`. - model_configs = [ModelConfig(**model_config_kwargs)] - self.model_configs: list[ModelConfig] = model_configs - self.metric_to_model_configs: dict[str, list[ModelConfig]] = ( - metric_to_model_configs or {} - ) - if len(self.model_configs) > 1 or any( - len(model_config) > 1 - for model_config in self.metric_to_model_configs.values() - ): - raise NotImplementedError("Only one model config per metric is supported.") + # check if surrogate_spec is provided + if surrogate_spec is None: + # create surrogate spec from deprecated arguments + model_config = get_model_config_from_deprecated_args( + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + surrogate_spec = SurrogateSpec( + model_configs=[model_config], allow_batched_models=allow_batched_models + ) + + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) - self.allow_batched_models = allow_batched_models + self.surrogate_spec: SurrogateSpec = surrogate_spec # Store the last dataset used to fit the model for a given metric(s). # If the new dataset is identical, we will skip model fitting for that metric. # The keys are `tuple(dataset.outcome_names)`. @@ -445,11 +589,7 @@ def __init__( self._model: Model | None = None def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__}" - f" model_configs={self.model_configs}," - f" metric_to_model_configs={self.metric_to_model_configs}>" - ) + return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>" @property def model(self) -> Model: @@ -614,20 +754,22 @@ def fit( # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. if ( - len(self.model_configs) == 1 - and self.model_configs[0].botorch_model_class is None + len(self.surrogate_spec.model_configs) == 1 + and self.surrogate_spec.model_configs[0].botorch_model_class is None ): default_botorch_model_class = choose_model_class( datasets=datasets, search_space_digest=search_space_digest ) else: - default_botorch_model_class = self.model_configs[0].botorch_model_class + default_botorch_model_class = self.surrogate_spec.model_configs[ + 0 + ].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, botorch_model_class=not_none(default_botorch_model_class), - model_configs=self.model_configs, - allow_batched_models=self.allow_batched_models, - metric_to_model_configs=self.metric_to_model_configs, + model_configs=self.surrogate_spec.model_configs, + allow_batched_models=self.surrogate_spec.allow_batched_models, + metric_to_model_configs=self.surrogate_spec.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: @@ -646,7 +788,7 @@ def fit( else: submodel_state_dict = state_dict model_config = None - if len(self.metric_to_model_configs) > 0: + if len(self.surrogate_spec.metric_to_model_configs) > 0: # if metric_to_model_configs is not empty, then # we are using a model list and each dataset # should have only one outcome. @@ -655,7 +797,7 @@ def fit( "Each dataset should have only one outcome when " "metric_to_model_configs is specified." ) - model_config_list = self.metric_to_model_configs.get( + model_config_list = self.surrogate_spec.metric_to_model_configs.get( dataset.outcome_names[0] ) @@ -663,7 +805,7 @@ def fit( if model_config_list is not None: model_config = model_config_list[0] if model_config is None: - model_config = self.model_configs[0] + model_config = self.surrogate_spec.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, @@ -828,10 +970,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: """Serialize attributes of this surrogate, to be passed back to it as kwargs on reinstantiation. """ - return { - "model_configs": self.model_configs, - "metric_to_model_configs": self.metric_to_model_configs, - } + return {"surrogate_spec": self.surrogate_spec} def _extract_construct_input_transform_args( self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 389b3072175..797719a0f5c 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -21,9 +21,8 @@ from ax.models.torch.botorch_modular.model import ( BoTorchModel, choose_botorch_acqf_class, - SurrogateSpec, ) -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( choose_model_class, construct_acquisition_and_optimizer_options, @@ -719,32 +718,17 @@ def test_evaluate_acquisition_function( def test_surrogate_model_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel( - surrogate_spec=SurrogateSpec( - botorch_model_kwargs={"some_option": "some_value"} - ) + surrogate_spec = SurrogateSpec( + botorch_model_kwargs={"some_option": "some_value"} ) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={"some_option": "some_value"}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - model_configs=[], - metric_to_model_configs={}, - allow_batched_models=True, + surrogate_spec=surrogate_spec, ) @mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) @@ -753,28 +737,15 @@ def test_surrogate_model_options_propagation( def test_surrogate_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel(surrogate_spec=SurrogateSpec(allow_batched_models=False)) + surrogate_spec = SurrogateSpec(allow_batched_models=False) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - model_configs=[], - metric_to_model_configs={}, - allow_batched_models=False, + surrogate_spec=surrogate_spec, ) @mock.patch( diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 6bbd600df90..eadce0951bf 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -19,7 +19,11 @@ from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate +from ax.models.torch.botorch_modular.surrogate import ( + _extract_model_kwargs, + Surrogate, + SurrogateSpec, +) from ax.models.torch.botorch_modular.utils import ( choose_model_class, fit_botorch_model, @@ -39,7 +43,7 @@ from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood from botorch.models.transforms.input import InputPerturbation, Normalize -from botorch.models.transforms.outcome import Standardize +from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import SupervisedDataset from gpytorch.constraints import GreaterThan, Interval from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel @@ -202,7 +206,7 @@ def _get_surrogate( mll_options = None if use_outcome_transform: - outcome_transform_classes = [Standardize] + outcome_transform_classes: list[type[OutcomeTransform]] = [Standardize] outcome_transform_options = {"Standardize": {"m": 1}} else: outcome_transform_classes = None @@ -222,10 +226,15 @@ def test_init(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) self.assertEqual( - surrogate.model_configs[0].botorch_model_class, botorch_model_class + surrogate.surrogate_spec.model_configs[0].botorch_model_class, + botorch_model_class, + ) + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].mll_class, self.mll_class ) - self.assertEqual(surrogate.model_configs[0].mll_class, self.mll_class) - self.assertTrue(surrogate.allow_batched_models) # True by default + self.assertTrue( + surrogate.surrogate_spec.allow_batched_models + ) # True by default def test_clone_reset(self) -> None: surrogate = self._get_surrogate(botorch_model_class=SingleTaskGP)[0] @@ -454,7 +463,7 @@ def test_construct_model(self) -> None: model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - model_config=surrogate.model_configs[0], + model_config=surrogate.surrogate_spec.model_configs[0], default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, @@ -480,7 +489,7 @@ def test_construct_model(self) -> None: new_model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - model_config=surrogate.model_configs[0], + model_config=surrogate.surrogate_spec.model_configs[0], default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, @@ -517,7 +526,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: "likelihood_class": FixedNoiseGaussianLikelihood, } if use_model_config: - surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) else: surrogate = Surrogate(**model_config_kwargs) with self.assertRaisesRegex(UserInputError, "does not support"): @@ -536,7 +549,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: "likelihood_options": {"noise_constraint": noise_constraint}, } if use_model_config: - surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) else: surrogate = Surrogate(**model_config_kwargs) surrogate.fit( @@ -552,7 +569,8 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: noise_constraint.__dict__, ) self.assertEqual( - surrogate.model_configs[0].mll_class, LeaveOneOutPseudoLikelihood + surrogate.surrogate_spec.model_configs[0].mll_class, + LeaveOneOutPseudoLikelihood, ) self.assertEqual(type(model.covar_module), RBFKernel) self.assertEqual(model.covar_module.ard_num_dims, 3) @@ -562,11 +580,13 @@ def test_construct_custom_model_with_config(self) -> None: def test_construct_model_with_metric_to_model_configs(self) -> None: surrogate = Surrogate( - metric_to_model_configs={ - "metric": [ModelConfig()], - "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], - }, - model_configs=[ModelConfig(covar_module_class=LinearKernel)], + surrogate_spec=SurrogateSpec( + metric_to_model_configs={ + "metric": [ModelConfig()], + "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], + }, + model_configs=[ModelConfig(covar_module_class=LinearKernel)], + ) ) training_data = self.training_data + [ SupervisedDataset( @@ -600,13 +620,13 @@ def test_multiple_model_configs_error(self) -> None: with self.assertRaisesRegex( NotImplementedError, "Only one model config per metric is supported." ): - Surrogate( + SurrogateSpec( model_configs=[ModelConfig(), ModelConfig()], ) with self.assertRaisesRegex( NotImplementedError, "Only one model config per metric is supported." ): - Surrogate( + SurrogateSpec( metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]}, ) @@ -720,10 +740,7 @@ def test_best_out_of_sample_point(self) -> None: def test_serialize_attributes_as_kwargs(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - expected = { - "model_configs": surrogate.model_configs, - "metric_to_model_configs": surrogate.metric_to_model_configs, - } + expected = {"surrogate_spec": surrogate.surrogate_spec} self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) @mock_botorch_optimize @@ -752,7 +769,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.model_configs[0].input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.training_data, @@ -905,7 +922,7 @@ def setUp(self) -> None: ) def test_init(self) -> None: - model_config = self.surrogate.model_configs[0] + model_config = self.surrogate.surrogate_spec.model_configs[0] self.assertEqual( [model_config.botorch_model_class] * 2, [*self.botorch_submodel_class_per_outcome.values()], @@ -925,7 +942,9 @@ def test_init(self) -> None: def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock ) -> None: - self.surrogate.model_configs[0].model_options.update({"output_tasks": [2]}) + self.surrogate.surrogate_spec.model_configs[0].model_options.update( + {"output_tasks": [2]} + ) for fixed_noise in (False, True): mock_fit.reset_mock() mock_MTGP_construct_inputs.reset_mock() @@ -1021,7 +1040,7 @@ def test_fit( # pyre-ignore[6]: Incompatible parameter type: In call # `issubclass`, for 1st positional argument, expected # `Type[typing.Any]` but got `Optional[Type[Model]]`. - surrogate.model_configs[0].botorch_model_class, + surrogate.surrogate_spec.model_configs[0].botorch_model_class, MultiTaskGP, ) search_space_digest = ( @@ -1206,7 +1225,8 @@ def test_construct_custom_model(self) -> None: models = checked_cast(ModelListGP, surrogate._model).models self.assertEqual(len(models), 2) self.assertEqual( - surrogate.model_configs[0].mll_class, ExactMarginalLogLikelihood + surrogate.surrogate_spec.model_configs[0].mll_class, + ExactMarginalLogLikelihood, ) # Make sure we properly copied the transforms. self.assertNotEqual( @@ -1260,7 +1280,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.model_configs[0].input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.supervised_training_data, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 5b31efb70f8..140ad1b33bd 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -44,8 +44,7 @@ TransitionCriterion, TrialBasedCriterion, ) -from ax.models.torch.botorch_modular.model import SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( batch_trial_from_json, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 02498d03669..3f0aad31e66 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -96,8 +96,8 @@ TransitionCriterion, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 24c8a634e57..a2e17f95acb 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -600,11 +600,11 @@ def test_encode_decode_surrogate_spec(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) - org_as_dict = dataclasses.asdict(org_object) - converted_as_dict = dataclasses.asdict(converted_object) + org_as_dict = dataclasses.asdict(org_object)["model_configs"][0] + converted_as_dict = dataclasses.asdict(converted_object)["model_configs"][0] # Covar module kwargs will fail comparison. Manually compare. - org_covar_kwargs = org_as_dict.pop("covar_module_kwargs") - converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs") + org_covar_kwargs = org_as_dict.pop("covar_module_options") + converted_covar_kwargs = converted_as_dict.pop("covar_module_options") self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys()) for k in org_covar_kwargs: org_ = org_covar_kwargs[k] diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 6aaf96ab9de..8e130319fe2 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -28,7 +28,7 @@ from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.service.scheduler import SchedulerOptions from ax.utils.common.constants import Keys from ax.utils.testing.core_stubs import ( @@ -189,7 +189,11 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod: model=Models.BOTORCH_MODULAR, num_trials=-1, model_kwargs={ - "surrogate": Surrogate(SingleTaskGP), + "surrogate": Surrogate( + surrogate_spec=SurrogateSpec( + botorch_model_class=SingleTaskGP + ) + ), # TODO: tests should better reflect defaults and not # re-implement this logic. "botorch_acqf_class": qNoisyExpectedImprovement, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7d90a339df1..007375df242 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -99,9 +99,9 @@ ) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec +from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.sebo import SEBOAcquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index da21088c63c..dc8f72147ec 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -24,19 +24,20 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904956474, - "executionStopTime": 1730904963470, - "id": "about-preview", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" }, - "originalKey": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", + "id": "about-preview", + "originalKey": "f4e8ae18-2aa3-4943-a15a-29851889445c", "outputsInitialized": true, - "requestMsgId": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", - "serverExecutionDuration": 4351.2808320811 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916291451, + "executionStopTime": 1730916298337, + "serverExecutionDuration": 4531.2523420434, + "collapsed": false, + "requestMsgId": "f4e8ae18-2aa3-4943-a15a-29851889445c", + "customOutput": null }, "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", @@ -48,7 +49,8 @@ "\n", "# Ax wrappers for BoTorch components\n", "from ax.models.torch.botorch_modular.model import BoTorchModel\n", - "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", + "from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec\n", + "from ax.models.torch.botorch_modular.utils import ModelConfig\n", "\n", "# Experiment examination utilities\n", "from ax.service.utils.report_utils import exp_to_df\n", @@ -76,14 +78,14 @@ "output_type": "stream", "name": "stderr", "text": [ - "I1106 065557.352 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" + "I1106 100452.333 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "I1106 065557.353 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" + "I1106 100452.334 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" ] } ] @@ -91,15 +93,15 @@ { "cell_type": "markdown", "metadata": { - "id": "northern-affairs", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" }, + "id": "northern-affairs", "originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "# Setup and Usage of BoTorch Models in Ax\n", @@ -125,15 +127,15 @@ { "cell_type": "markdown", "metadata": { - "id": "pending-support", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" }, + "id": "pending-support", "originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 1. Quick-start example\n", @@ -144,19 +146,20 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904958719, - "executionStopTime": 1730904963489, - "id": "parental-sending", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" }, - "originalKey": "146a9a1b-52e6-4d76-9fc5-79025b392673", + "id": "parental-sending", + "originalKey": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09", "outputsInitialized": true, - "requestMsgId": "146a9a1b-52e6-4d76-9fc5-79025b392673", - "serverExecutionDuration": 42.191333021037 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916294801, + "executionStopTime": 1730916298389, + "serverExecutionDuration": 22.605526028201, + "collapsed": false, + "requestMsgId": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09", + "customOutput": null }, "source": [ "experiment = get_branin_experiment(with_trial=True)\n", @@ -168,7 +171,7 @@ "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:56:01] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 10:04:56] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] } ] @@ -176,19 +179,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904959343, - "executionStopTime": 1730904964892, - "id": "rough-somerset", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" }, - "originalKey": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", + "id": "rough-somerset", + "originalKey": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92", "outputsInitialized": true, - "requestMsgId": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", - "serverExecutionDuration": 870.78339292202 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916295849, + "executionStopTime": 1730916299900, + "serverExecutionDuration": 852.73489891551, + "collapsed": false, + "requestMsgId": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92" }, "source": [ "# `Models` automatically selects a model + model bridge combination.\n", @@ -196,7 +199,9 @@ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", - " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " ), # Optional, will use default if unspecified\n", " botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n", ")" ], @@ -206,7 +211,7 @@ "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:56:01] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 10:04:57] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] } ] @@ -214,15 +219,15 @@ { "cell_type": "markdown", "metadata": { - "id": "hairy-wiring", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" }, + "id": "hairy-wiring", "originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." @@ -231,19 +236,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904961333, - "executionStopTime": 1730904964907, - "id": "consecutive-summary", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" }, - "originalKey": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", + "id": "consecutive-summary", + "originalKey": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f", "outputsInitialized": true, - "requestMsgId": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", - "serverExecutionDuration": 284.31268292479 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916299852, + "executionStopTime": 1730916300305, + "serverExecutionDuration": 233.20194100961, + "collapsed": false, + "requestMsgId": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f" }, "source": [ "generator_run = model_bridge_with_GPEI.gen(n=1)\n", @@ -254,7 +259,7 @@ { "output_type": "execute_result", "data": { - "text/plain": "Arm(parameters={'x1': -5.0, 'x2': 0.0})" + "text/plain": "Arm(parameters={'x1': 10.0, 'x2': 15.0})" }, "metadata": {}, "execution_count": 4 @@ -264,15 +269,15 @@ { "cell_type": "markdown", "metadata": { - "id": "diverse-richards", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" }, + "id": "diverse-richards", "originalKey": "804bac30-db07-4444-98a2-7a5f05007495", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "-----\n", @@ -285,34 +290,34 @@ { "cell_type": "markdown", "metadata": { - "id": "grand-committee", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" }, + "id": "grand-committee", "originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2. BoTorchModel = Surrogate + Acquisition\n", "\n", - "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." + "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." ] }, { "cell_type": "markdown", "metadata": { - "id": "thousand-blanket", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" }, + "id": "thousand-blanket", "originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2A. Example that uses defaults and requires no options\n", @@ -323,23 +328,21 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905022012, - "executionStopTime": 1730905022269, - "id": "changing-xerox", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" }, - "originalKey": "509e30d7-dc32-4190-836f-f221cacbff31", + "id": "changing-xerox", + "originalKey": "fa86552a-0b80-4040-a0c4-61a0de37bdc1", "outputsInitialized": true, - "requestMsgId": "509e30d7-dc32-4190-836f-f221cacbff31", - "serverExecutionDuration": 1.972567057237 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916302730, + "executionStopTime": 1730916304031, + "serverExecutionDuration": 1.7747740494087, + "collapsed": false, + "requestMsgId": "fa86552a-0b80-4040-a0c4-61a0de37bdc1" }, "source": [ - "from ax.models.torch.botorch_modular.utils import ModelConfig\n", - "\n", "# The surrogate is not specified, so it will be auto-selected\n", "# during `model.fit`.\n", "GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)\n", @@ -347,27 +350,29 @@ "# The acquisition class is not specified, so it will be\n", "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", "GPEI_model = BoTorchModel(\n", - " surrogate=Surrogate(model_configs=[ModelConfig(SingleTaskGP)])\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " )\n", ")\n", "\n", "# Both the surrogate and acquisition class will be auto-selected.\n", "GPEI_model = BoTorchModel()" ], - "execution_count": 7, + "execution_count": 5, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "lovely-mechanics", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" }, + "id": "lovely-mechanics", "originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2B. Example with all the options\n", @@ -377,24 +382,24 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905023369, - "executionStopTime": 1730905023622, - "id": "twenty-greek", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" }, - "originalKey": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", + "id": "twenty-greek", + "originalKey": "8d824e37-b087-4bab-9b16-4354e9509df7", "outputsInitialized": true, - "requestMsgId": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", - "serverExecutionDuration": 1.9794970285147 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916305930, + "executionStopTime": 1730916306168, + "serverExecutionDuration": 2.6916969800368, + "collapsed": false, + "requestMsgId": "8d824e37-b087-4bab-9b16-4354e9509df7" }, "source": [ "model = BoTorchModel(\n", " # Optional `Surrogate` specification to use instead of default\n", - " surrogate=Surrogate(\n", + " surrogate_spec=SurrogateSpec(\n", " model_configs=[\n", " ModelConfig(\n", " # BoTorch `Model` type\n", @@ -421,21 +426,21 @@ " warm_start_refit=True,\n", ")" ], - "execution_count": 8, + "execution_count": 6, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "fourth-material", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" }, + "id": "fourth-material", "originalKey": "7140bb19-09b4-4abe-951d-53902ae07833", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", @@ -450,15 +455,15 @@ { "cell_type": "markdown", "metadata": { - "id": "violent-course", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" }, + "id": "violent-course", "originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" @@ -467,18 +472,18 @@ { "cell_type": "markdown", "metadata": { - "id": "unlike-football", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", "showInput": false }, + "id": "unlike-football", "originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3a. Making a `Surrogate` from BoTorch `Model`:\n", @@ -486,29 +491,30 @@ "\n", "If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n", "1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n", - "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`." + "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via the `model_options` argument to `ModelConfig` in `SurrogateSpec`." ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905050779, - "executionStopTime": 1730905051022, - "id": "dynamic-university", - "isAgentGenerated": false, - "language": "python", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" }, - "originalKey": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", + "id": "dynamic-university", + "originalKey": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34", "outputsInitialized": true, - "requestMsgId": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", - "serverExecutionDuration": 2.5736939860508 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916308518, + "executionStopTime": 1730916308769, + "serverExecutionDuration": 2.4644429795444, + "collapsed": false, + "requestMsgId": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34" }, "source": [ + "from botorch.models.model import Model\n", "from botorch.utils.datasets import SupervisedDataset\n", "\n", "\n", @@ -530,31 +536,31 @@ " }\n", "\n", "\n", - "surrogate = Surrogate(\n", + "surrogate_spec = SurrogateSpec(\n", " model_configs=[\n", " ModelConfig(\n", " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", " # Optional dict of additional keyword arguments to `MyModelClass`\n", " model_options={},\n", " )\n", - " ],\n", + " ]\n", ")" ], - "execution_count": 9, + "execution_count": 7, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "otherwise-context", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" }, + "id": "otherwise-context", "originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." @@ -563,15 +569,15 @@ { "cell_type": "markdown", "metadata": { - "id": "northern-invite", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" }, + "id": "northern-invite", "originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" @@ -580,47 +586,49 @@ { "cell_type": "markdown", "metadata": { - "id": "surrounded-denial", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", "showInput": false }, + "id": "surrounded-denial", "originalKey": "d4861847-b757-4fcd-9f35-ba258080812c", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Steps to set up any `AcquisitionFunction` in Ax are:\n", - "1. Define an input constructor function. The purpose of this method is to produce arguments to an acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", + "1. Define an input constructor function. The purpose of this method is to produce arguments to a acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", " 1. Note that the new input constructor needs to be decorated with `@acqf_input_constructor(AcquisitionFunctionClass)` to register it.\n", - "2. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", - "3. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." + "3. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", + "4. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905067704, - "executionStopTime": 1730905067939, - "id": "interested-search", - "isAgentGenerated": false, - "language": "python", "metadata": { "code_folding": [], "hidden_ranges": [], "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" }, - "originalKey": "075243f0-c2d6-46fd-9a18-5b77af258abf", + "id": "interested-search", + "originalKey": "f188f40b-64ba-4b0c-b216-f3dea8c7465e", "outputsInitialized": true, - "requestMsgId": "075243f0-c2d6-46fd-9a18-5b77af258abf", - "serverExecutionDuration": 5.0081580411643 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916310518, + "executionStopTime": 1730916310772, + "serverExecutionDuration": 4.9752569757402, + "collapsed": false, + "requestMsgId": "f188f40b-64ba-4b0c-b216-f3dea8c7465e", + "customOutput": null }, "source": [ + "from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse\n", "from botorch.acquisition.acquisition import AcquisitionFunction\n", "from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n", "from botorch.utils.datasets import SupervisedDataset\n", @@ -642,6 +650,7 @@ " pass\n", "\n", "\n", + "\n", "# 2-3. Specifying `botorch_acqf_class` and `acquisition_options`\n", "BoTorchModel(\n", " botorch_acqf_class=MyAcquisitionFunctionClass,\n", @@ -655,7 +664,7 @@ " },\n", ")" ], - "execution_count": 10, + "execution_count": 8, "outputs": [ { "output_type": "execute_result", @@ -663,22 +672,22 @@ "text/plain": "BoTorchModel" }, "metadata": {}, - "execution_count": 10 + "execution_count": 8 } ] }, { "cell_type": "markdown", "metadata": { - "id": "metallic-imaging", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" }, + "id": "metallic-imaging", "originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." @@ -687,15 +696,15 @@ { "cell_type": "markdown", "metadata": { - "id": "descending-australian", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" }, + "id": "descending-australian", "originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 4. Using `Models.BOTORCH_MODULAR` \n", @@ -708,19 +717,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905070804, - "executionStopTime": 1730905071313, - "id": "attached-border", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" }, - "originalKey": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", + "id": "attached-border", + "originalKey": "052cf2e4-8de0-4ec3-a3f9-478194b10928", "outputsInitialized": true, - "requestMsgId": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", - "serverExecutionDuration": 262.54303695168 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916311983, + "executionStopTime": 1730916312395, + "serverExecutionDuration": 202.78578903526, + "collapsed": false, + "requestMsgId": "052cf2e4-8de0-4ec3-a3f9-478194b10928" }, "source": [ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", @@ -729,13 +738,13 @@ ")\n", "model_bridge_with_GPEI.gen(1)" ], - "execution_count": 11, + "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:51] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 10:05:12] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] }, { @@ -744,31 +753,31 @@ "text/plain": "GeneratorRun(1 arms, total weight 1.0)" }, "metadata": {}, - "execution_count": 11 + "execution_count": 9 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905071744, - "executionStopTime": 1730905071981, - "id": "powerful-gamma", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "89930a31-e058-434b-b587-181931e247b6" }, - "originalKey": "6ec047f0-a75e-4733-b4d7-20045627f0b2", + "id": "powerful-gamma", + "originalKey": "b7f924fe-f3d9-4211-b402-421f4c90afe5", "outputsInitialized": true, - "requestMsgId": "6ec047f0-a75e-4733-b4d7-20045627f0b2", - "serverExecutionDuration": 2.8772989753634 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916312432, + "executionStopTime": 1730916312657, + "serverExecutionDuration": 3.1334219966084, + "collapsed": false, + "requestMsgId": "b7f924fe-f3d9-4211-b402-421f4c90afe5" }, "source": [ "model_bridge_with_GPEI.model.botorch_acqf_class" ], - "execution_count": 12, + "execution_count": 10, "outputs": [ { "output_type": "execute_result", @@ -776,31 +785,31 @@ "text/plain": "botorch.acquisition.logei.qLogNoisyExpectedImprovement" }, "metadata": {}, - "execution_count": 12 + "execution_count": 10 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905076319, - "executionStopTime": 1730905076580, - "id": "improved-replication", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" }, - "originalKey": "68caa729-792d-4692-bc4d-4c0b8d03e022", + "id": "improved-replication", + "originalKey": "942f1817-8d40-48f8-8725-90c25a079e4c", "outputsInitialized": true, - "requestMsgId": "68caa729-792d-4692-bc4d-4c0b8d03e022", - "serverExecutionDuration": 3.2039729412645 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916312847, + "executionStopTime": 1730916313093, + "serverExecutionDuration": 3.410067060031, + "collapsed": false, + "requestMsgId": "942f1817-8d40-48f8-8725-90c25a079e4c" }, "source": [ - "type(model_bridge_with_GPEI.model.surrogate.model)" + "model_bridge_with_GPEI.model.surrogate.model.__class__" ], - "execution_count": 13, + "execution_count": 11, "outputs": [ { "output_type": "execute_result", @@ -808,22 +817,22 @@ "text/plain": "botorch.models.gp_regression.SingleTaskGP" }, "metadata": {}, - "execution_count": 13 + "execution_count": 11 } ] }, { "cell_type": "markdown", "metadata": { - "id": "connected-sheet", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" }, + "id": "connected-sheet", "originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" @@ -832,19 +841,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905078786, - "executionStopTime": 1730905079551, - "id": "documentary-jurisdiction", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" }, - "originalKey": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", + "id": "documentary-jurisdiction", + "originalKey": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b", "outputsInitialized": true, - "requestMsgId": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", - "serverExecutionDuration": 512.32317101676 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916314009, + "executionStopTime": 1730916314736, + "serverExecutionDuration": 518.53136904538, + "collapsed": false, + "requestMsgId": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b" }, "source": [ "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", @@ -855,27 +864,27 @@ ")\n", "model_bridge_with_EHVI.gen(1)" ], - "execution_count": 14, + "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 10:05:14] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" + "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" + "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" ] }, { @@ -884,31 +893,31 @@ "text/plain": "GeneratorRun(1 arms, total weight 1.0)" }, "metadata": {}, - "execution_count": 14 + "execution_count": 12 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905079324, - "executionStopTime": 1730905079651, - "id": "changed-maintenance", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" }, - "originalKey": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", + "id": "changed-maintenance", + "originalKey": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f", "outputsInitialized": true, - "requestMsgId": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", - "serverExecutionDuration": 3.0141109600663 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916314586, + "executionStopTime": 1730916314842, + "serverExecutionDuration": 3.3097150735557, + "collapsed": false, + "requestMsgId": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f" }, "source": [ "model_bridge_with_EHVI.model.botorch_acqf_class" ], - "execution_count": 15, + "execution_count": 13, "outputs": [ { "output_type": "execute_result", @@ -916,31 +925,31 @@ "text/plain": "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement" }, "metadata": {}, - "execution_count": 15 + "execution_count": 13 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905087115, - "executionStopTime": 1730905087362, - "id": "operating-shelf", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" }, - "originalKey": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", + "id": "operating-shelf", + "originalKey": "1e980e3c-09f6-44c1-a79f-f59867de0c3e", "outputsInitialized": true, - "requestMsgId": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", - "serverExecutionDuration": 3.2659249845892 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916315097, + "executionStopTime": 1730916315308, + "serverExecutionDuration": 3.4662369871512, + "collapsed": false, + "requestMsgId": "1e980e3c-09f6-44c1-a79f-f59867de0c3e" }, "source": [ - "type(model_bridge_with_EHVI.model.surrogate.model)" + "model_bridge_with_EHVI.model.surrogate.model.__class__" ], - "execution_count": 16, + "execution_count": 14, "outputs": [ { "output_type": "execute_result", @@ -948,22 +957,22 @@ "text/plain": "botorch.models.gp_regression.SingleTaskGP" }, "metadata": {}, - "execution_count": 16 + "execution_count": 14 } ] }, { "cell_type": "markdown", "metadata": { - "id": "fatal-butterfly", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" }, + "id": "fatal-butterfly", "originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " @@ -972,15 +981,15 @@ { "cell_type": "markdown", "metadata": { - "id": "hearing-interface", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" }, + "id": "hearing-interface", "originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5. Utilizing `BoTorchModel` in generation strategies\n", @@ -993,19 +1002,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905104120, - "executionStopTime": 1730905104377, - "id": "received-registration", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" }, - "originalKey": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", + "id": "received-registration", + "originalKey": "4ee172c8-0648-418b-9968-647e8e916507", "outputsInitialized": true, - "requestMsgId": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", - "serverExecutionDuration": 2.4519630242139 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916316730, + "executionStopTime": 1730916316968, + "serverExecutionDuration": 2.2927720565349, + "collapsed": false, + "requestMsgId": "4ee172c8-0648-418b-9968-647e8e916507" }, "source": [ "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", @@ -1028,7 +1037,7 @@ " # No limit on how many generator runs will be produced\n", " num_trials=-1,\n", " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", - " \"surrogate\": Surrogate(\n", + " \"surrogate_spec\": SurrogateSpec(\n", " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", " ),\n", " \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n", @@ -1037,21 +1046,21 @@ " ]\n", ")" ], - "execution_count": 17, + "execution_count": 15, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "logical-windsor", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" }, + "id": "logical-windsor", "originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" @@ -1060,19 +1069,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905105295, - "executionStopTime": 1730905105570, - "id": "viral-cheese", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" }, - "originalKey": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", + "id": "viral-cheese", + "originalKey": "1b7d0cfc-f7cf-477d-b109-d34db9604938", "outputsInitialized": true, - "requestMsgId": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", - "serverExecutionDuration": 4.1721769375727 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916317751, + "executionStopTime": 1730916318153, + "serverExecutionDuration": 3.9581339806318, + "collapsed": false, + "requestMsgId": "1b7d0cfc-f7cf-477d-b109-d34db9604938" }, "source": [ "experiment = get_branin_experiment(minimize=True)\n", @@ -1080,13 +1089,13 @@ "assert len(experiment.trials) == 0\n", "experiment.search_space" ], - "execution_count": 18, + "execution_count": 16, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:58:25] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 10:05:18] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] }, { @@ -1095,22 +1104,22 @@ "text/plain": "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])" }, "metadata": {}, - "execution_count": 18 + "execution_count": 16 } ] }, { "cell_type": "markdown", "metadata": { - "id": "incident-newspaper", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" }, + "id": "incident-newspaper", "originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5a. Specifying `pending_observations`\n", @@ -1122,19 +1131,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905107391, - "executionStopTime": 1730905109351, - "id": "casual-spread", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" }, - "originalKey": "36ca75ec-b37c-498c-a487-5652cd3dc34b", + "id": "casual-spread", + "originalKey": "fe7437c5-8834-46cc-94b2-91782d91ee96", "outputsInitialized": true, - "requestMsgId": "36ca75ec-b37c-498c-a487-5652cd3dc34b", - "serverExecutionDuration": 1696.8911510194 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916318830, + "executionStopTime": 1730916321328, + "serverExecutionDuration": 2274.8276960338, + "collapsed": false, + "requestMsgId": "fe7437c5-8834-46cc-94b2-91782d91ee96" }, "source": [ "for _ in range(10):\n", @@ -1155,7 +1164,7 @@ "\n", " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" ], - "execution_count": 19, + "execution_count": 17, "outputs": [ { "output_type": "stream", @@ -1175,7 +1184,14 @@ "output_type": "stream", "name": "stdout", "text": [ - "Completed trial #6, suggested by BoTorch.\nCompleted trial #7, suggested by BoTorch.\n" + "Completed trial #6, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #7, suggested by BoTorch.\n" ] }, { @@ -1197,15 +1213,15 @@ { "cell_type": "markdown", "metadata": { - "id": "circular-vermont", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" }, + "id": "circular-vermont", "originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" @@ -1214,37 +1230,37 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905108832, - "executionStopTime": 1730905109399, - "id": "significant-particular", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" }, - "originalKey": "68d567be-27d6-4244-b22f-d6c53ed2d303", + "id": "significant-particular", + "originalKey": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac", "outputsInitialized": true, - "requestMsgId": "68d567be-27d6-4244-b22f-d6c53ed2d303", - "serverExecutionDuration": 32.219067099504 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916319576, + "executionStopTime": 1730916321368, + "serverExecutionDuration": 35.789265064523, + "collapsed": false, + "requestMsgId": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac" }, "source": [ "exp_to_df(experiment)" ], - "execution_count": 20, + "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[WARNING 11-06 06:58:29] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" + "[WARNING 11-06 10:05:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" ] }, { "output_type": "execute_result", "data": { - "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 79.581199 1.380743 12.280850\n1 1 1_0 COMPLETED ... 17.366840 6.989676 1.438049\n2 2 2_0 COMPLETED ... 61.299075 6.097525 7.568626\n3 3 3_0 COMPLETED ... 71.268812 -3.293570 4.231312\n4 4 4_0 COMPLETED ... 3.831238 -2.268755 10.230113\n5 5 5_0 COMPLETED ... 4.246417 -3.354258 10.886093\n6 6 6_0 COMPLETED ... 6.712767 9.467421 0.000000\n7 7 7_0 COMPLETED ... 17.508300 -5.000000 15.000000\n8 8 8_0 COMPLETED ... 2.507635 10.000000 2.251628\n9 9 9_0 COMPLETED ... 1.318731 8.983844 2.177537\n\n[10 rows x 7 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
trial_indexarm_nametrial_statusgeneration_methodbraninx1x2
000_0COMPLETEDSobol79.5811991.38074312.280850
111_0COMPLETEDSobol17.3668406.9896761.438049
222_0COMPLETEDSobol61.2990756.0975257.568626
333_0COMPLETEDSobol71.268812-3.2935704.231312
444_0COMPLETEDSobol3.831238-2.26875510.230113
555_0COMPLETEDBoTorch4.246417-3.35425810.886093
666_0COMPLETEDBoTorch6.7127679.4674210.000000
777_0COMPLETEDBoTorch17.508300-5.00000015.000000
888_0COMPLETEDBoTorch2.50763510.0000002.251628
999_0COMPLETEDBoTorch1.3187318.9838442.177537
\n
", + "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 26.922506 -2.244023 5.435609\n1 1 1_0 COMPLETED ... 74.072517 3.535081 10.528676\n2 2 2_0 COMPLETED ... 5.610080 8.741262 3.706691\n3 3 3_0 COMPLETED ... 56.657623 -0.069164 12.199905\n4 4 4_0 COMPLETED ... 27.932704 0.862014 1.306074\n5 5 5_0 COMPLETED ... 5.423062 10.000000 4.868411\n6 6 6_0 COMPLETED ... 9.250452 10.000000 0.299753\n7 7 7_0 COMPLETED ... 308.129096 -5.000000 0.000000\n8 8 8_0 COMPLETED ... 17.607633 0.778687 5.717932\n9 9 9_0 COMPLETED ... 132.986209 1.451895 15.000000\n\n[10 rows x 7 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
trial_indexarm_nametrial_statusgeneration_methodbraninx1x2
000_0COMPLETEDSobol26.922506-2.2440235.435609
111_0COMPLETEDSobol74.0725173.53508110.528676
222_0COMPLETEDSobol5.6100808.7412623.706691
333_0COMPLETEDSobol56.657623-0.06916412.199905
444_0COMPLETEDSobol27.9327040.8620141.306074
555_0COMPLETEDBoTorch5.42306210.0000004.868411
666_0COMPLETEDBoTorch9.25045210.0000000.299753
777_0COMPLETEDBoTorch308.129096-5.0000000.000000
888_0COMPLETEDBoTorch17.6076330.7786875.717932
999_0COMPLETEDBoTorch132.9862091.45189515.000000
\n
", "application/vnd.dataresource+json": { "schema": { "fields": [ @@ -1293,9 +1309,9 @@ "arm_name": "0_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 79.5811993025, - "x1": 1.3807432353, - "x2": 12.2808498144 + "branin": 26.9225058393, + "x1": -2.2440226376, + "x2": 5.4356087744 }, { "index": 1, @@ -1303,9 +1319,9 @@ "arm_name": "1_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 17.3668397479, - "x1": 6.9896756578, - "x2": 1.4380489429 + "branin": 74.0725171307, + "x1": 3.535081069, + "x2": 10.5286756391 }, { "index": 2, @@ -1313,9 +1329,9 @@ "arm_name": "2_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 61.2990749448, - "x1": 6.097525456, - "x2": 7.5686264969 + "branin": 5.6100798162, + "x1": 8.7412616471, + "x2": 3.7066908041 }, { "index": 3, @@ -1323,9 +1339,9 @@ "arm_name": "3_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 71.268812081, - "x1": -3.2935703546, - "x2": 4.231311623 + "branin": 56.6576230229, + "x1": -0.0691637676, + "x2": 12.1999046439 }, { "index": 4, @@ -1333,9 +1349,9 @@ "arm_name": "4_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 3.8312383283, - "x1": -2.2687551333, - "x2": 10.2301133936 + "branin": 27.9327040954, + "x1": 0.8620139305, + "x2": 1.3060741313 }, { "index": 5, @@ -1343,9 +1359,9 @@ "arm_name": "5_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 4.2464169491, - "x1": -3.3542581623, - "x2": 10.8860926765 + "branin": 5.4230616409, + "x1": 10, + "x2": 4.8684112356 }, { "index": 6, @@ -1353,9 +1369,9 @@ "arm_name": "6_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 6.7127667696, - "x1": 9.4674207228, - "x2": 0 + "branin": 9.2504522786, + "x1": 10, + "x2": 0.2997526514 }, { "index": 7, @@ -1363,9 +1379,9 @@ "arm_name": "7_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 17.5082995158, + "branin": 308.1290960116, "x1": -5, - "x2": 15 + "x2": 0 }, { "index": 8, @@ -1373,9 +1389,9 @@ "arm_name": "8_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 2.5076350936, - "x1": 10, - "x2": 2.2516281613 + "branin": 17.6076329851, + "x1": 0.7786866384, + "x2": 5.7179317285 }, { "index": 9, @@ -1383,30 +1399,30 @@ "arm_name": "9_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 1.3187305399, - "x1": 8.983844442, - "x2": 2.1775366828 + "branin": 132.9862090134, + "x1": 1.451894724, + "x2": 15 } ] } }, "metadata": {}, - "execution_count": 20 + "execution_count": 18 } ] }, { "cell_type": "markdown", "metadata": { - "id": "obvious-transparency", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" }, + "id": "obvious-transparency", "originalKey": "633c66af-a89f-4f03-a88b-866767d0a52f", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 6. Customizing a `Surrogate` or `Acquisition`\n", @@ -1419,19 +1435,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905111356, - "executionStopTime": 1730905111601, - "id": "organizational-balance", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" }, - "originalKey": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", + "id": "organizational-balance", + "originalKey": "2949718a-8a4e-41e5-91ac-5b020eface47", "outputsInitialized": true, - "requestMsgId": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", - "serverExecutionDuration": 2.4916339898482 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916320585, + "executionStopTime": 1730916321384, + "serverExecutionDuration": 2.2059100447223, + "collapsed": false, + "requestMsgId": "2949718a-8a4e-41e5-91ac-5b020eface47" }, "source": [ "from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n", @@ -1451,21 +1467,21 @@ " ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n", " ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default" ], - "execution_count": 21, + "execution_count": 19, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "theoretical-horizon", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" }, + "id": "theoretical-horizon", "originalKey": "0ec8606d-9d5b-4bcb-ad7e-f54839ad6f9b", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):" @@ -1474,19 +1490,19 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905113131, - "executionStopTime": 1730905113387, - "id": "approximate-rolling", - "isAgentGenerated": false, - "language": "python", "metadata": { "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" }, - "originalKey": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", + "id": "approximate-rolling", + "originalKey": "e231ea1e-c70d-48dc-b6c6-1611c5ea1b26", "outputsInitialized": true, - "requestMsgId": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", - "serverExecutionDuration": 12.22031598445 + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916321675, + "executionStopTime": 1730916321901, + "serverExecutionDuration": 12.351316981949, + "collapsed": false, + "requestMsgId": "e231ea1e-c70d-48dc-b6c6-1611c5ea1b26" }, "source": [ "Models.BOTORCH_MODULAR(\n", @@ -1496,13 +1512,13 @@ " botorch_acqf_class=MyAcquisitionFunctionClass,\n", ")" ], - "execution_count": 22, + "execution_count": 20, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:58:33] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 10:05:21] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] }, { @@ -1511,22 +1527,22 @@ "text/plain": "TorchModelBridge(model=BoTorchModel)" }, "metadata": {}, - "execution_count": 22 + "execution_count": 20 } ] }, { "cell_type": "markdown", "metadata": { - "id": "representative-implement", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" }, + "id": "representative-implement", "originalKey": "cdcfb2bc-3016-4681-9fff-407f28321c3f", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n", @@ -1542,15 +1558,15 @@ { "cell_type": "markdown", "metadata": { - "id": "framed-intermediate", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" }, + "id": "framed-intermediate", "originalKey": "ff03d674-f584-403f-ba65-f1bab921845b", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "------" @@ -1559,15 +1575,15 @@ { "cell_type": "markdown", "metadata": { - "id": "metropolitan-feedback", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" }, + "id": "metropolitan-feedback", "originalKey": "f71fcfa1-fc59-4bfb-84d6-b94ea5298bfa", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 1: Methods available on `BoTorchModel`\n", @@ -1590,15 +1606,15 @@ { "cell_type": "markdown", "metadata": { - "id": "possible-transsexual", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" }, + "id": "possible-transsexual", "originalKey": "91cedde4-8911-441f-af05-eb124581cbbc", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 2: Default surrogate models and acquisition functions\n", @@ -1617,15 +1633,15 @@ { "cell_type": "markdown", "metadata": { - "id": "continuous-strain", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" }, + "id": "continuous-strain", "originalKey": "c8b0f933-8df6-479b-aa61-db75ca877624", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n", @@ -1636,15 +1652,15 @@ { "cell_type": "markdown", "metadata": { - "id": "broadband-voice", - "isAgentGenerated": false, - "language": "markdown", "metadata": { "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" }, + "id": "broadband-voice", "originalKey": "4d82f49a-3a8b-42f0-a4f5-5c079b793344", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "The two options for handling this error are:\n",