From a0c14265ecbdf828b91bd02b8695e04b125ffbaf Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 6 Nov 2024 08:15:37 -0800 Subject: [PATCH] move deprecation warnings to SurrogateSpec and pass surrogate_spec to Surrogate (#3025) Summary: See title. This avoids defining the same arguments in two places, per feedback from Sait. Moving deprecation warnings to SurrogateSpec means warnings are raised raised while specifying the model rather than when it is instantiated. Reviewed By: saitcakmak Differential Revision: D65321401 --- ax/benchmark/methods/modular_botorch.py | 2 +- ax/benchmark/tests/methods/test_methods.py | 2 +- ax/modelbridge/registry.py | 6 +- ax/modelbridge/tests/test_registry.py | 6 +- ax/models/torch/botorch_modular/model.py | 73 +- ax/models/torch/botorch_modular/surrogate.py | 265 +++-- ax/models/torch/tests/test_model.py | 45 +- ax/models/torch/tests/test_surrogate.py | 76 +- ax/storage/json_store/decoder.py | 3 +- ax/storage/json_store/registry.py | 4 +- .../json_store/tests/test_json_store.py | 8 +- ax/utils/testing/benchmark_stubs.py | 8 +- ax/utils/testing/core_stubs.py | 4 +- tutorials/modular_botax.ipynb | 908 +++++++----------- 14 files changed, 640 insertions(+), 770 deletions(-) diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 74518e3ebe7..d4fbeda78a2 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -14,7 +14,7 @@ from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Models -from ax.models.torch.botorch_modular.model import SurrogateSpec +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.service.scheduler import SchedulerOptions from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement diff --git a/ax/benchmark/tests/methods/test_methods.py b/ax/benchmark/tests/methods/test_methods.py index 9c63e2705e0..0dc1aa60cee 100644 --- a/ax/benchmark/tests/methods/test_methods.py +++ b/ax/benchmark/tests/methods/test_methods.py @@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None: self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient) surrogate_spec = model_kwargs["surrogate_spec"] self.assertEqual( - surrogate_spec.botorch_model_class.__name__, + surrogate_spec.model_configs[0].botorch_model_class.__name__, "SingleTaskGP", ) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 9d903ad8c9e..27f0ea653a8 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -59,10 +59,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.random.uniform import UniformGenerator from ax.models.torch.botorch import BotorchModel -from ax.models.torch.botorch_modular.model import ( - BoTorchModel as ModularBoTorchModel, - SurrogateSpec, -) +from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.models.torch.cbo_sac import SACBO from ax.utils.common.kwargs import ( diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 7f77bbde548..20f17d7ce70 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -28,8 +28,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase @@ -100,7 +100,7 @@ def test_SAASBO(self) -> None: SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.model_configs[0].botorch_model_class, + saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class, SaasFullyBayesianSingleTaskGP, ) diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 03804f6f7de..ebd04a2f2e3 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -10,7 +10,6 @@ import warnings from collections import OrderedDict from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field from typing import Any import numpy.typing as npt @@ -23,12 +22,11 @@ get_rounding_func, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( check_outcome_dataset_match, choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, - ModelConfig, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -38,53 +36,10 @@ from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.deterministic import FixedSingleSampleModel -from botorch.models.model import Model -from botorch.models.transforms.input import InputTransform -from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset -from gpytorch.kernels.kernel import Kernel -from gpytorch.likelihoods import Likelihood -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor -@dataclass(frozen=True) -class SurrogateSpec: - """ - Fields in the SurrogateSpec dataclass correspond to arguments in - ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which - outcomes the Surrogate is responsible for modeling. - When ``BotorchModel.fit`` is called, these fields will be used to construct the - requisite Surrogate objects. - If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. - """ - - botorch_model_class: type[Model] | None = None - botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) - - mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood - mll_kwargs: dict[str, Any] = field(default_factory=dict) - - covar_module_class: type[Kernel] | None = None - covar_module_kwargs: dict[str, Any] | None = None - - likelihood_class: type[Likelihood] | None = None - likelihood_kwargs: dict[str, Any] | None = None - - input_transform_classes: list[type[InputTransform]] | None = None - input_transform_options: dict[str, dict[str, Any]] | None = None - - outcome_transform_classes: list[type[OutcomeTransform]] | None = None - outcome_transform_options: dict[str, dict[str, Any]] | None = None - - allow_batched_models: bool = True - - model_configs: list[ModelConfig] = field(default_factory=list) - metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) - outcomes: list[str] = field(default_factory=list) - - class BoTorchModel(TorchModel, Base): """**All classes in 'botorch_modular' directory are under construction, incomplete, and should be treated as alpha @@ -230,31 +185,17 @@ def fit( # If a surrogate has not been constructed, construct it. if self._surrogate is None: - if (spec := self.surrogate_spec) is not None: - self._surrogate = Surrogate( - botorch_model_class=spec.botorch_model_class, - model_options=spec.botorch_model_kwargs, - mll_class=spec.mll_class, - mll_options=spec.mll_kwargs, - covar_module_class=spec.covar_module_class, - covar_module_options=spec.covar_module_kwargs, - likelihood_class=spec.likelihood_class, - likelihood_options=spec.likelihood_kwargs, - input_transform_classes=spec.input_transform_classes, - input_transform_options=spec.input_transform_options, - outcome_transform_classes=spec.outcome_transform_classes, - outcome_transform_options=spec.outcome_transform_options, - model_configs=spec.model_configs, - metric_to_model_configs=spec.metric_to_model_configs, - allow_batched_models=spec.allow_batched_models, - ) + if self.surrogate_spec is not None: + self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec) else: self._surrogate = Surrogate() # Fit the surrogate. - for config in self.surrogate.model_configs: + for config in self.surrogate.surrogate_spec.model_configs: config.model_options.update(additional_model_inputs) - for config_list in self.surrogate.metric_to_model_configs.values(): + for ( + config_list + ) in self.surrogate.surrogate_spec.metric_to_model_configs.values(): for config in config_list: config.model_options.update(additional_model_inputs) self.surrogate.fit( diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 80ecbd2a484..802ed179b70 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -13,6 +13,7 @@ from collections import OrderedDict from collections.abc import Sequence from copy import deepcopy +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -271,17 +272,169 @@ def _set_formatted_inputs( formatted_model_inputs[input_name] = input_class(**input_options) -def _raise_deprecation_warning(*args: Any, **kwargs: Any) -> None: +def _raise_deprecation_warning( + is_surrogate: bool = False, + **kwargs: Any, +) -> bool: + """Raise deprecation warnings for deprecated arguments. + + Args: + is_surrogate: A boolean indicating whether the warning is called from + Surrogate. + + Returns: + A boolean indicating whether any deprecation warnings were raised. + """ + msg = "{k} is deprecated and will be removed in a future version. " + if is_surrogate: + msg += "Please specify {k} via `surrogate_spec.model_configs`." + else: + msg += "Please specify {k} via `model_configs`." + warnings_raised = False + default_is_dict = {"botorch_model_kwargs", "mll_kwargs"} for k, v in kwargs.items(): - if (v is not None and k != "mll_class") or ( + should_raise = False + if k in default_is_dict: + if v != {}: + should_raise = True + elif (v is not None and k != "mll_class") or ( k == "mll_class" and v is not ExactMarginalLogLikelihood ): + should_raise = True + if should_raise: warnings.warn( - f"{k} is deprecated and will be removed in a future version. " - f"Please specify {k} via `model_configs`.", + msg.format(k=k), DeprecationWarning, stacklevel=3, ) + warnings_raised = True + return warnings_raised + + +def get_model_config_from_deprecated_args( + botorch_model_class: type[Model] | None, + model_options: dict[str, Any] | None, + mll_class: type[MarginalLogLikelihood] | None, + mll_options: dict[str, Any] | None, + outcome_transform_classes: list[type[OutcomeTransform]] | None, + outcome_transform_options: dict[str, dict[str, Any]] | None, + input_transform_classes: list[type[InputTransform]] | None, + input_transform_options: dict[str, dict[str, Any]] | None, + covar_module_class: type[Kernel] | None, + covar_module_options: dict[str, Any] | None, + likelihood_class: type[Likelihood] | None, + likelihood_options: dict[str, Any] | None, +) -> ModelConfig: + """Construct a ModelConfig from deprecated arguments.""" + model_config_kwargs = { + "botorch_model_class": botorch_model_class, + "model_options": (model_options or {}).copy(), + "mll_class": mll_class, + "mll_options": (mll_options or {}).copy(), + "outcome_transform_classes": outcome_transform_classes, + "outcome_transform_options": outcome_transform_options, + "input_transform_classes": input_transform_classes, + "input_transform_options": input_transform_options, + "covar_module_class": covar_module_class, + "covar_module_options": covar_module_options, + "likelihood_class": likelihood_class, + "likelihood_options": likelihood_options, + } + model_config_kwargs = { + k: v for k, v in model_config_kwargs.items() if v is not None + } + # pyre-fixme [6]: Incompatible parameter type [6]: In call + # `ModelConfig.__init__`, for 1st positional argument, expected + # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], + # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], + # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, + # Model]], Type[Likelihood], Type[Kernel]]`. + return ModelConfig(**model_config_kwargs) + + +@dataclass(frozen=True) +class SurrogateSpec: + """ + Fields in the SurrogateSpec dataclass correspond to arguments in + ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which + outcomes the Surrogate is responsible for modeling. + When ``BotorchModel.fit`` is called, these fields will be used to construct the + requisite Surrogate objects. + If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. + """ + + botorch_model_class: type[Model] | None = None + botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) + + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_kwargs: dict[str, Any] = field(default_factory=dict) + + covar_module_class: type[Kernel] | None = None + covar_module_kwargs: dict[str, Any] | None = None + + likelihood_class: type[Likelihood] | None = None + likelihood_kwargs: dict[str, Any] | None = None + + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = None + + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] | None = None + + allow_batched_models: bool = True + + model_configs: list[ModelConfig] = field(default_factory=list) + metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) + outcomes: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + warnings_raised = _raise_deprecation_warning( + is_surrogate=False, + botorch_model_class=self.botorch_model_class, + botorch_model_kwargs=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_kwargs=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + if len(self.model_configs) == 0: + model_config = get_model_config_from_deprecated_args( + botorch_model_class=self.botorch_model_class, + model_options=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_options=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + # re-initialize with the non-deprecated arguments + self.__init__( + model_configs=[model_config], + metric_to_model_configs=self.metric_to_model_configs, + allow_batched_models=self.allow_batched_models, + outcomes=self.outcomes, + ) + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) + if len(self.model_configs) > 1 or any( + len(model_config) > 1 + for model_config in self.metric_to_model_configs.values() + ): + raise NotImplementedError("Only one model config per metric is supported.") class Surrogate(Base): @@ -359,23 +512,23 @@ class string names and the values are dictionaries of input transform def __init__( self, + surrogate_spec: SurrogateSpec | None = None, botorch_model_class: type[Model] | None = None, model_options: dict[str, Any] | None = None, mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood, mll_options: dict[str, Any] | None = None, - outcome_transform_classes: Sequence[type[OutcomeTransform]] | None = None, + outcome_transform_classes: list[type[OutcomeTransform]] | None = None, outcome_transform_options: dict[str, dict[str, Any]] | None = None, - input_transform_classes: Sequence[type[InputTransform]] | None = None, + input_transform_classes: list[type[InputTransform]] | None = None, input_transform_options: dict[str, dict[str, Any]] | None = None, covar_module_class: type[Kernel] | None = None, covar_module_options: dict[str, Any] | None = None, likelihood_class: type[Likelihood] | None = None, likelihood_options: dict[str, Any] | None = None, - model_configs: list[ModelConfig] | None = None, - metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> None: - _raise_deprecation_warning( + warnings_raised = _raise_deprecation_warning( + is_surrogate=True, botorch_model_class=botorch_model_class, model_options=model_options, mll_class=mll_class, @@ -389,43 +542,34 @@ def __init__( likelihood_class=likelihood_class, likelihood_options=likelihood_options, ) - model_configs = model_configs or [] - if len(model_configs) == 0: - model_config_kwargs = { - "botorch_model_class": botorch_model_class, - "model_options": (model_options or {}).copy(), - "mll_class": mll_class, - "mll_options": mll_options, - "outcome_transform_classes": outcome_transform_classes, - "outcome_transform_options": outcome_transform_options, - "input_transform_classes": input_transform_classes, - "input_transform_options": input_transform_options, - "covar_module_class": covar_module_class, - "covar_module_options": covar_module_options, - "likelihood_class": likelihood_class, - "likelihood_options": likelihood_options, - } - model_config_kwargs = { - k: v for k, v in model_config_kwargs.items() if v is not None - } - # pyre-fixme [6]: Incompatible parameter type [6]: In call - # `ModelConfig.__init__`, for 1st positional argument, expected - # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], - # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], - # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, - # Model]], Type[Likelihood], Type[Kernel]]`. - model_configs = [ModelConfig(**model_config_kwargs)] - self.model_configs: list[ModelConfig] = model_configs - self.metric_to_model_configs: dict[str, list[ModelConfig]] = ( - metric_to_model_configs or {} - ) - if len(self.model_configs) > 1 or any( - len(model_config) > 1 - for model_config in self.metric_to_model_configs.values() - ): - raise NotImplementedError("Only one model config per metric is supported.") + # check if surrogate_spec is provided + if surrogate_spec is None: + # create surrogate spec from deprecated arguments + model_config = get_model_config_from_deprecated_args( + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + surrogate_spec = SurrogateSpec( + model_configs=[model_config], allow_batched_models=allow_batched_models + ) + + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) - self.allow_batched_models = allow_batched_models + self.surrogate_spec: SurrogateSpec = surrogate_spec # Store the last dataset used to fit the model for a given metric(s). # If the new dataset is identical, we will skip model fitting for that metric. # The keys are `tuple(dataset.outcome_names)`. @@ -445,11 +589,7 @@ def __init__( self._model: Model | None = None def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__}" - f" model_configs={self.model_configs}," - f" metric_to_model_configs={self.metric_to_model_configs}>" - ) + return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>" @property def model(self) -> Model: @@ -614,20 +754,22 @@ def fit( # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. if ( - len(self.model_configs) == 1 - and self.model_configs[0].botorch_model_class is None + len(self.surrogate_spec.model_configs) == 1 + and self.surrogate_spec.model_configs[0].botorch_model_class is None ): default_botorch_model_class = choose_model_class( datasets=datasets, search_space_digest=search_space_digest ) else: - default_botorch_model_class = self.model_configs[0].botorch_model_class + default_botorch_model_class = self.surrogate_spec.model_configs[ + 0 + ].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, botorch_model_class=not_none(default_botorch_model_class), - model_configs=self.model_configs, - allow_batched_models=self.allow_batched_models, - metric_to_model_configs=self.metric_to_model_configs, + model_configs=self.surrogate_spec.model_configs, + allow_batched_models=self.surrogate_spec.allow_batched_models, + metric_to_model_configs=self.surrogate_spec.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: @@ -646,7 +788,7 @@ def fit( else: submodel_state_dict = state_dict model_config = None - if len(self.metric_to_model_configs) > 0: + if len(self.surrogate_spec.metric_to_model_configs) > 0: # if metric_to_model_configs is not empty, then # we are using a model list and each dataset # should have only one outcome. @@ -655,7 +797,7 @@ def fit( "Each dataset should have only one outcome when " "metric_to_model_configs is specified." ) - model_config_list = self.metric_to_model_configs.get( + model_config_list = self.surrogate_spec.metric_to_model_configs.get( dataset.outcome_names[0] ) @@ -663,7 +805,7 @@ def fit( if model_config_list is not None: model_config = model_config_list[0] if model_config is None: - model_config = self.model_configs[0] + model_config = self.surrogate_spec.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, @@ -828,10 +970,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: """Serialize attributes of this surrogate, to be passed back to it as kwargs on reinstantiation. """ - return { - "model_configs": self.model_configs, - "metric_to_model_configs": self.metric_to_model_configs, - } + return {"surrogate_spec": self.surrogate_spec} def _extract_construct_input_transform_args( self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 389b3072175..797719a0f5c 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -21,9 +21,8 @@ from ax.models.torch.botorch_modular.model import ( BoTorchModel, choose_botorch_acqf_class, - SurrogateSpec, ) -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( choose_model_class, construct_acquisition_and_optimizer_options, @@ -719,32 +718,17 @@ def test_evaluate_acquisition_function( def test_surrogate_model_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel( - surrogate_spec=SurrogateSpec( - botorch_model_kwargs={"some_option": "some_value"} - ) + surrogate_spec = SurrogateSpec( + botorch_model_kwargs={"some_option": "some_value"} ) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={"some_option": "some_value"}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - model_configs=[], - metric_to_model_configs={}, - allow_batched_models=True, + surrogate_spec=surrogate_spec, ) @mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) @@ -753,28 +737,15 @@ def test_surrogate_model_options_propagation( def test_surrogate_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel(surrogate_spec=SurrogateSpec(allow_batched_models=False)) + surrogate_spec = SurrogateSpec(allow_batched_models=False) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - model_configs=[], - metric_to_model_configs={}, - allow_batched_models=False, + surrogate_spec=surrogate_spec, ) @mock.patch( diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 6bbd600df90..eadce0951bf 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -19,7 +19,11 @@ from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate +from ax.models.torch.botorch_modular.surrogate import ( + _extract_model_kwargs, + Surrogate, + SurrogateSpec, +) from ax.models.torch.botorch_modular.utils import ( choose_model_class, fit_botorch_model, @@ -39,7 +43,7 @@ from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood from botorch.models.transforms.input import InputPerturbation, Normalize -from botorch.models.transforms.outcome import Standardize +from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import SupervisedDataset from gpytorch.constraints import GreaterThan, Interval from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel @@ -202,7 +206,7 @@ def _get_surrogate( mll_options = None if use_outcome_transform: - outcome_transform_classes = [Standardize] + outcome_transform_classes: list[type[OutcomeTransform]] = [Standardize] outcome_transform_options = {"Standardize": {"m": 1}} else: outcome_transform_classes = None @@ -222,10 +226,15 @@ def test_init(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) self.assertEqual( - surrogate.model_configs[0].botorch_model_class, botorch_model_class + surrogate.surrogate_spec.model_configs[0].botorch_model_class, + botorch_model_class, + ) + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].mll_class, self.mll_class ) - self.assertEqual(surrogate.model_configs[0].mll_class, self.mll_class) - self.assertTrue(surrogate.allow_batched_models) # True by default + self.assertTrue( + surrogate.surrogate_spec.allow_batched_models + ) # True by default def test_clone_reset(self) -> None: surrogate = self._get_surrogate(botorch_model_class=SingleTaskGP)[0] @@ -454,7 +463,7 @@ def test_construct_model(self) -> None: model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - model_config=surrogate.model_configs[0], + model_config=surrogate.surrogate_spec.model_configs[0], default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, @@ -480,7 +489,7 @@ def test_construct_model(self) -> None: new_model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - model_config=surrogate.model_configs[0], + model_config=surrogate.surrogate_spec.model_configs[0], default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, @@ -517,7 +526,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: "likelihood_class": FixedNoiseGaussianLikelihood, } if use_model_config: - surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) else: surrogate = Surrogate(**model_config_kwargs) with self.assertRaisesRegex(UserInputError, "does not support"): @@ -536,7 +549,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: "likelihood_options": {"noise_constraint": noise_constraint}, } if use_model_config: - surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) else: surrogate = Surrogate(**model_config_kwargs) surrogate.fit( @@ -552,7 +569,8 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None: noise_constraint.__dict__, ) self.assertEqual( - surrogate.model_configs[0].mll_class, LeaveOneOutPseudoLikelihood + surrogate.surrogate_spec.model_configs[0].mll_class, + LeaveOneOutPseudoLikelihood, ) self.assertEqual(type(model.covar_module), RBFKernel) self.assertEqual(model.covar_module.ard_num_dims, 3) @@ -562,11 +580,13 @@ def test_construct_custom_model_with_config(self) -> None: def test_construct_model_with_metric_to_model_configs(self) -> None: surrogate = Surrogate( - metric_to_model_configs={ - "metric": [ModelConfig()], - "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], - }, - model_configs=[ModelConfig(covar_module_class=LinearKernel)], + surrogate_spec=SurrogateSpec( + metric_to_model_configs={ + "metric": [ModelConfig()], + "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], + }, + model_configs=[ModelConfig(covar_module_class=LinearKernel)], + ) ) training_data = self.training_data + [ SupervisedDataset( @@ -600,13 +620,13 @@ def test_multiple_model_configs_error(self) -> None: with self.assertRaisesRegex( NotImplementedError, "Only one model config per metric is supported." ): - Surrogate( + SurrogateSpec( model_configs=[ModelConfig(), ModelConfig()], ) with self.assertRaisesRegex( NotImplementedError, "Only one model config per metric is supported." ): - Surrogate( + SurrogateSpec( metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]}, ) @@ -720,10 +740,7 @@ def test_best_out_of_sample_point(self) -> None: def test_serialize_attributes_as_kwargs(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - expected = { - "model_configs": surrogate.model_configs, - "metric_to_model_configs": surrogate.metric_to_model_configs, - } + expected = {"surrogate_spec": surrogate.surrogate_spec} self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) @mock_botorch_optimize @@ -752,7 +769,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.model_configs[0].input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.training_data, @@ -905,7 +922,7 @@ def setUp(self) -> None: ) def test_init(self) -> None: - model_config = self.surrogate.model_configs[0] + model_config = self.surrogate.surrogate_spec.model_configs[0] self.assertEqual( [model_config.botorch_model_class] * 2, [*self.botorch_submodel_class_per_outcome.values()], @@ -925,7 +942,9 @@ def test_init(self) -> None: def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock ) -> None: - self.surrogate.model_configs[0].model_options.update({"output_tasks": [2]}) + self.surrogate.surrogate_spec.model_configs[0].model_options.update( + {"output_tasks": [2]} + ) for fixed_noise in (False, True): mock_fit.reset_mock() mock_MTGP_construct_inputs.reset_mock() @@ -1021,7 +1040,7 @@ def test_fit( # pyre-ignore[6]: Incompatible parameter type: In call # `issubclass`, for 1st positional argument, expected # `Type[typing.Any]` but got `Optional[Type[Model]]`. - surrogate.model_configs[0].botorch_model_class, + surrogate.surrogate_spec.model_configs[0].botorch_model_class, MultiTaskGP, ) search_space_digest = ( @@ -1206,7 +1225,8 @@ def test_construct_custom_model(self) -> None: models = checked_cast(ModelListGP, surrogate._model).models self.assertEqual(len(models), 2) self.assertEqual( - surrogate.model_configs[0].mll_class, ExactMarginalLogLikelihood + surrogate.surrogate_spec.model_configs[0].mll_class, + ExactMarginalLogLikelihood, ) # Make sure we properly copied the transforms. self.assertNotEqual( @@ -1260,7 +1280,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.model_configs[0].input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.supervised_training_data, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 5b31efb70f8..140ad1b33bd 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -44,8 +44,7 @@ TransitionCriterion, TrialBasedCriterion, ) -from ax.models.torch.botorch_modular.model import SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( batch_trial_from_json, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 02498d03669..3f0aad31e66 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -96,8 +96,8 @@ TransitionCriterion, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 24c8a634e57..a2e17f95acb 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -600,11 +600,11 @@ def test_encode_decode_surrogate_spec(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) - org_as_dict = dataclasses.asdict(org_object) - converted_as_dict = dataclasses.asdict(converted_object) + org_as_dict = dataclasses.asdict(org_object)["model_configs"][0] + converted_as_dict = dataclasses.asdict(converted_object)["model_configs"][0] # Covar module kwargs will fail comparison. Manually compare. - org_covar_kwargs = org_as_dict.pop("covar_module_kwargs") - converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs") + org_covar_kwargs = org_as_dict.pop("covar_module_options") + converted_covar_kwargs = converted_as_dict.pop("covar_module_options") self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys()) for k in org_covar_kwargs: org_ = org_covar_kwargs[k] diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 6aaf96ab9de..8e130319fe2 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -28,7 +28,7 @@ from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.service.scheduler import SchedulerOptions from ax.utils.common.constants import Keys from ax.utils.testing.core_stubs import ( @@ -189,7 +189,11 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod: model=Models.BOTORCH_MODULAR, num_trials=-1, model_kwargs={ - "surrogate": Surrogate(SingleTaskGP), + "surrogate": Surrogate( + surrogate_spec=SurrogateSpec( + botorch_model_class=SingleTaskGP + ) + ), # TODO: tests should better reflect defaults and not # re-implement this logic. "botorch_acqf_class": qNoisyExpectedImprovement, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7d90a339df1..007375df242 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -99,9 +99,9 @@ ) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec +from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.sebo import SEBOAcquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index da21088c63c..faefc62239f 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -24,19 +24,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904956474, - "executionStopTime": 1730904963470, - "id": "about-preview", + "originalKey": "91bb9820-6624-44d6-9659-76e883624f80", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" - }, - "originalKey": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", - "outputsInitialized": true, - "requestMsgId": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", - "serverExecutionDuration": 4351.2808320811 + "executionStartTime": 1730900850726, + "executionStopTime": 1730900856668, + "serverExecutionDuration": 4261.2324819202, + "collapsed": false, + "requestMsgId": "91bb9820-6624-44d6-9659-76e883624f80" }, "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", @@ -76,14 +72,14 @@ "output_type": "stream", "name": "stderr", "text": [ - "I1106 065557.352 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" + "I1106 054731.719 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "I1106 065557.353 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" + "I1106 054731.720 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" ] } ] @@ -91,15 +87,11 @@ { "cell_type": "markdown", "metadata": { - "id": "northern-affairs", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" - }, - "originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c", + "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "# Setup and Usage of BoTorch Models in Ax\n", @@ -125,15 +117,11 @@ { "cell_type": "markdown", "metadata": { - "id": "pending-support", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" - }, - "originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1", + "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 1. Quick-start example\n", @@ -144,19 +132,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904958719, - "executionStopTime": 1730904963489, - "id": "parental-sending", + "originalKey": "07947d7f-3391-41d3-a089-da77c4f0af83", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" - }, - "originalKey": "146a9a1b-52e6-4d76-9fc5-79025b392673", - "outputsInitialized": true, - "requestMsgId": "146a9a1b-52e6-4d76-9fc5-79025b392673", - "serverExecutionDuration": 42.191333021037 + "executionStartTime": 1730900854173, + "executionStopTime": 1730900856687, + "serverExecutionDuration": 26.501741958782, + "collapsed": false, + "requestMsgId": "07947d7f-3391-41d3-a089-da77c4f0af83" }, "source": [ "experiment = get_branin_experiment(with_trial=True)\n", @@ -168,7 +152,7 @@ "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:56:01] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 05:47:35] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] } ] @@ -176,37 +160,38 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904959343, - "executionStopTime": 1730904964892, - "id": "rough-somerset", + "originalKey": "8ca4cd63-4d6c-467b-a107-07d902e714fb", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" - }, - "originalKey": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", - "outputsInitialized": true, - "requestMsgId": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", - "serverExecutionDuration": 870.78339292202 + "executionStartTime": 1730901083682, + "executionStopTime": 1730901083970, + "serverExecutionDuration": 11.818251106888, + "collapsed": false, + "requestMsgId": "8ca4cd63-4d6c-467b-a107-07d902e714fb" }, "source": [ + "from ax.models.torch.botorch_modular.surrogate import SurrogateSpec\n", + "from ax.models.torch.botorch_modular.utils import ModelConfig\n", + "\n", "# `Models` automatically selects a model + model bridge combination.\n", "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", - " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " ), # Optional, will use default if unspecified\n", " botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n", ")" ], - "execution_count": 3, + "execution_count": 21, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:56:01] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 05:51:23] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] } ] @@ -214,15 +199,11 @@ { "cell_type": "markdown", "metadata": { - "id": "hairy-wiring", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" - }, - "originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f", + "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." @@ -231,48 +212,31 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730904961333, - "executionStopTime": 1730904964907, - "id": "consecutive-summary", + "originalKey": "0fec8fcd-8c14-4ad8-8aa0-fe066d37f2d9", + "outputsInitialized": false, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" - }, - "originalKey": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", - "outputsInitialized": true, - "requestMsgId": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", - "serverExecutionDuration": 284.31268292479 + "executionStartTime": 1730900855362, + "executionStopTime": 1730900857172, + "serverExecutionDuration": null, + "collapsed": false, + "requestMsgId": "0fec8fcd-8c14-4ad8-8aa0-fe066d37f2d9" }, "source": [ "generator_run = model_bridge_with_GPEI.gen(n=1)\n", "generator_run.arms[0]" ], - "execution_count": 4, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": "Arm(parameters={'x1': -5.0, 'x2': 0.0})" - }, - "metadata": {}, - "execution_count": 4 - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "diverse-richards", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" - }, - "originalKey": "804bac30-db07-4444-98a2-7a5f05007495", + "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "-----\n", @@ -285,34 +249,26 @@ { "cell_type": "markdown", "metadata": { - "id": "grand-committee", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" - }, - "originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d", + "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2. BoTorchModel = Surrogate + Acquisition\n", "\n", - "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." + "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." ] }, { "cell_type": "markdown", "metadata": { - "id": "thousand-blanket", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" - }, - "originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e", + "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2A. Example that uses defaults and requires no options\n", @@ -323,19 +279,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905022012, - "executionStopTime": 1730905022269, - "id": "changing-xerox", + "originalKey": "1530b9af-eab5-4fd4-8e9f-c586c40ea15d", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" - }, - "originalKey": "509e30d7-dc32-4190-836f-f221cacbff31", - "outputsInitialized": true, - "requestMsgId": "509e30d7-dc32-4190-836f-f221cacbff31", - "serverExecutionDuration": 1.972567057237 + "executionStartTime": 1730900857131, + "executionStopTime": 1730900858332, + "serverExecutionDuration": 2.7728870045394, + "collapsed": false, + "requestMsgId": "1530b9af-eab5-4fd4-8e9f-c586c40ea15d" }, "source": [ "from ax.models.torch.botorch_modular.utils import ModelConfig\n", @@ -353,21 +305,17 @@ "# Both the surrogate and acquisition class will be auto-selected.\n", "GPEI_model = BoTorchModel()" ], - "execution_count": 7, + "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "lovely-mechanics", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" - }, - "originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2", + "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2B. Example with all the options\n", @@ -377,24 +325,20 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905023369, - "executionStopTime": 1730905023622, - "id": "twenty-greek", + "originalKey": "706bd6c9-5311-4c28-9a57-d9f1b2e49c14", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" - }, - "originalKey": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", - "outputsInitialized": true, - "requestMsgId": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", - "serverExecutionDuration": 1.9794970285147 + "executionStartTime": 1730901117832, + "executionStopTime": 1730901118116, + "serverExecutionDuration": 2.0907880971208, + "collapsed": false, + "requestMsgId": "706bd6c9-5311-4c28-9a57-d9f1b2e49c14" }, "source": [ "model = BoTorchModel(\n", " # Optional `Surrogate` specification to use instead of default\n", - " surrogate=Surrogate(\n", + " surrogate_spec=SurrogateSpec(\n", " model_configs=[\n", " ModelConfig(\n", " # BoTorch `Model` type\n", @@ -421,21 +365,17 @@ " warm_start_refit=True,\n", ")" ], - "execution_count": 8, + "execution_count": 23, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "fourth-material", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" - }, - "originalKey": "7140bb19-09b4-4abe-951d-53902ae07833", + "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", @@ -450,15 +390,11 @@ { "cell_type": "markdown", "metadata": { - "id": "violent-course", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" - }, - "originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0", + "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" @@ -467,18 +403,13 @@ { "cell_type": "markdown", "metadata": { - "id": "unlike-football", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", - "showInput": false - }, - "originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc", + "code_folding": [], + "hidden_ranges": [], + "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3a. Making a `Surrogate` from BoTorch `Model`:\n", @@ -486,27 +417,23 @@ "\n", "If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n", "1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n", - "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`." + "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via the `model_options` argument to `ModelConfig` in `SurrogateSpec`." ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905050779, - "executionStopTime": 1730905051022, - "id": "dynamic-university", + "code_folding": [], + "hidden_ranges": [], + "originalKey": "dfb20ccf-5405-49c7-bf5b-dff40ed87dd5", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" - }, - "originalKey": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", - "outputsInitialized": true, - "requestMsgId": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", - "serverExecutionDuration": 2.5736939860508 + "executionStartTime": 1730900860005, + "executionStopTime": 1730900870858, + "serverExecutionDuration": 2.5424950290471, + "collapsed": false, + "requestMsgId": "dfb20ccf-5405-49c7-bf5b-dff40ed87dd5" }, "source": [ "from botorch.utils.datasets import SupervisedDataset\n", @@ -531,30 +458,27 @@ "\n", "\n", "surrogate = Surrogate(\n", - " model_configs=[\n", - " ModelConfig(\n", - " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", - " # Optional dict of additional keyword arguments to `MyModelClass`\n", - " model_options={},\n", - " )\n", - " ],\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[\n", + " ModelConfig(\n", + " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", + " model_options={}, # Optional dict of additional keyword arguments to `MyModelClass`\n", + " )\n", + " ]\n", + " ),\n", ")" ], - "execution_count": 9, + "execution_count": 6, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "otherwise-context", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" - }, - "originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666", + "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." @@ -563,15 +487,11 @@ { "cell_type": "markdown", "metadata": { - "id": "northern-invite", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" - }, - "originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129", + "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" @@ -580,18 +500,13 @@ { "cell_type": "markdown", "metadata": { - "id": "surrounded-denial", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", - "showInput": false - }, - "originalKey": "d4861847-b757-4fcd-9f35-ba258080812c", + "code_folding": [], + "hidden_ranges": [], + "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Steps to set up any `AcquisitionFunction` in Ax are:\n", @@ -604,21 +519,17 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905067704, - "executionStopTime": 1730905067939, - "id": "interested-search", + "code_folding": [], + "hidden_ranges": [], + "originalKey": "4212ac28-bdbb-4d88-a098-9f6a833511b5", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" - }, - "originalKey": "075243f0-c2d6-46fd-9a18-5b77af258abf", - "outputsInitialized": true, - "requestMsgId": "075243f0-c2d6-46fd-9a18-5b77af258abf", - "serverExecutionDuration": 5.0081580411643 + "executionStartTime": 1730900862795, + "executionStopTime": 1730900871963, + "serverExecutionDuration": 6.7860540002584, + "collapsed": false, + "requestMsgId": "4212ac28-bdbb-4d88-a098-9f6a833511b5" }, "source": [ "from botorch.acquisition.acquisition import AcquisitionFunction\n", @@ -655,7 +566,7 @@ " },\n", ")" ], - "execution_count": 10, + "execution_count": 7, "outputs": [ { "output_type": "execute_result", @@ -663,22 +574,18 @@ "text/plain": "BoTorchModel" }, "metadata": {}, - "execution_count": 10 + "execution_count": 7 } ] }, { "cell_type": "markdown", "metadata": { - "id": "metallic-imaging", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" - }, - "originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863", + "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." @@ -687,15 +594,11 @@ { "cell_type": "markdown", "metadata": { - "id": "descending-australian", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" - }, - "originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850", + "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 4. Using `Models.BOTORCH_MODULAR` \n", @@ -708,19 +611,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905070804, - "executionStopTime": 1730905071313, - "id": "attached-border", + "originalKey": "f2207909-ac67-40d5-a370-ffe64cf15515", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" - }, - "originalKey": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", - "outputsInitialized": true, - "requestMsgId": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", - "serverExecutionDuration": 262.54303695168 + "executionStartTime": 1730900864711, + "executionStopTime": 1730900874210, + "serverExecutionDuration": 1119.7519529378, + "collapsed": false, + "requestMsgId": "f2207909-ac67-40d5-a370-ffe64cf15515" }, "source": [ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", @@ -729,13 +628,13 @@ ")\n", "model_bridge_with_GPEI.gen(1)" ], - "execution_count": 11, + "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:51] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 05:47:52] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] }, { @@ -744,31 +643,27 @@ "text/plain": "GeneratorRun(1 arms, total weight 1.0)" }, "metadata": {}, - "execution_count": 11 + "execution_count": 8 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905071744, - "executionStopTime": 1730905071981, - "id": "powerful-gamma", + "originalKey": "ea6f8fe7-35bf-439e-83c8-c49ffa7ad8ea", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "89930a31-e058-434b-b587-181931e247b6" - }, - "originalKey": "6ec047f0-a75e-4733-b4d7-20045627f0b2", - "outputsInitialized": true, - "requestMsgId": "6ec047f0-a75e-4733-b4d7-20045627f0b2", - "serverExecutionDuration": 2.8772989753634 + "executionStartTime": 1730900865310, + "executionStopTime": 1730900874224, + "serverExecutionDuration": 3.8121930556372, + "collapsed": false, + "requestMsgId": "ea6f8fe7-35bf-439e-83c8-c49ffa7ad8ea" }, "source": [ "model_bridge_with_GPEI.model.botorch_acqf_class" ], - "execution_count": 12, + "execution_count": 9, "outputs": [ { "output_type": "execute_result", @@ -776,31 +671,27 @@ "text/plain": "botorch.acquisition.logei.qLogNoisyExpectedImprovement" }, "metadata": {}, - "execution_count": 12 + "execution_count": 9 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905076319, - "executionStopTime": 1730905076580, - "id": "improved-replication", + "originalKey": "945dff10-01fa-443a-845a-1d1890aee093", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" - }, - "originalKey": "68caa729-792d-4692-bc4d-4c0b8d03e022", - "outputsInitialized": true, - "requestMsgId": "68caa729-792d-4692-bc4d-4c0b8d03e022", - "serverExecutionDuration": 3.2039729412645 + "executionStartTime": 1730900925859, + "executionStopTime": 1730900926172, + "serverExecutionDuration": 3.6907501053065, + "collapsed": false, + "requestMsgId": "945dff10-01fa-443a-845a-1d1890aee093" }, "source": [ "type(model_bridge_with_GPEI.model.surrogate.model)" ], - "execution_count": 13, + "execution_count": 16, "outputs": [ { "output_type": "execute_result", @@ -808,22 +699,18 @@ "text/plain": "botorch.models.gp_regression.SingleTaskGP" }, "metadata": {}, - "execution_count": 13 + "execution_count": 16 } ] }, { "cell_type": "markdown", "metadata": { - "id": "connected-sheet", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" - }, - "originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08", + "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" @@ -832,19 +719,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905078786, - "executionStopTime": 1730905079551, - "id": "documentary-jurisdiction", + "originalKey": "a3306de2-1954-4619-a516-f62fbc4d2f9b", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" - }, - "originalKey": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", - "outputsInitialized": true, - "requestMsgId": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", - "serverExecutionDuration": 512.32317101676 + "executionStartTime": 1730900937698, + "executionStopTime": 1730900938557, + "serverExecutionDuration": 596.06455394533, + "collapsed": false, + "requestMsgId": "a3306de2-1954-4619-a516-f62fbc4d2f9b" }, "source": [ "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", @@ -855,27 +738,27 @@ ")\n", "model_bridge_with_EHVI.gen(1)" ], - "execution_count": 14, + "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 05:48:57] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" + "[INFO 11-06 05:48:58] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" + "[INFO 11-06 05:48:58] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" ] }, { @@ -884,31 +767,27 @@ "text/plain": "GeneratorRun(1 arms, total weight 1.0)" }, "metadata": {}, - "execution_count": 14 + "execution_count": 18 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905079324, - "executionStopTime": 1730905079651, - "id": "changed-maintenance", + "originalKey": "80833c65-3b68-492f-a417-619e8b18a91a", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" - }, - "originalKey": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", - "outputsInitialized": true, - "requestMsgId": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", - "serverExecutionDuration": 3.0141109600663 + "executionStartTime": 1730900938324, + "executionStopTime": 1730900938627, + "serverExecutionDuration": 3.1278820242733, + "collapsed": false, + "requestMsgId": "80833c65-3b68-492f-a417-619e8b18a91a" }, "source": [ "model_bridge_with_EHVI.model.botorch_acqf_class" ], - "execution_count": 15, + "execution_count": 19, "outputs": [ { "output_type": "execute_result", @@ -916,31 +795,27 @@ "text/plain": "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement" }, "metadata": {}, - "execution_count": 15 + "execution_count": 19 } ] }, { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905087115, - "executionStopTime": 1730905087362, - "id": "operating-shelf", + "originalKey": "f7645d44-0e38-4969-8360-547ad863bdbe", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" - }, - "originalKey": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", - "outputsInitialized": true, - "requestMsgId": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", - "serverExecutionDuration": 3.2659249845892 + "executionStartTime": 1730900939089, + "executionStopTime": 1730900939384, + "serverExecutionDuration": 3.1960729975253, + "collapsed": false, + "requestMsgId": "f7645d44-0e38-4969-8360-547ad863bdbe" }, "source": [ "type(model_bridge_with_EHVI.model.surrogate.model)" ], - "execution_count": 16, + "execution_count": 20, "outputs": [ { "output_type": "execute_result", @@ -948,22 +823,17 @@ "text/plain": "botorch.models.gp_regression.SingleTaskGP" }, "metadata": {}, - "execution_count": 16 + "execution_count": 20 } ] }, { "cell_type": "markdown", "metadata": { - "id": "fatal-butterfly", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" - }, - "originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb", + "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " @@ -972,15 +842,10 @@ { "cell_type": "markdown", "metadata": { - "id": "hearing-interface", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" - }, - "originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f", + "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5. Utilizing `BoTorchModel` in generation strategies\n", @@ -993,19 +858,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905104120, - "executionStopTime": 1730905104377, - "id": "received-registration", + "originalKey": "6505e024-7b59-4dd6-8bc7-7b9b7c660a4a", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" - }, - "originalKey": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", - "outputsInitialized": true, - "requestMsgId": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", - "serverExecutionDuration": 2.4519630242139 + "executionStartTime": 1730901513273, + "executionStopTime": 1730901513538, + "serverExecutionDuration": 2.3842720547691, + "collapsed": false, + "requestMsgId": "6505e024-7b59-4dd6-8bc7-7b9b7c660a4a" }, "source": [ "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", @@ -1028,7 +889,7 @@ " # No limit on how many generator runs will be produced\n", " num_trials=-1,\n", " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", - " \"surrogate\": Surrogate(\n", + " \"surrogate_spec\": SurrogateSpec(\n", " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", " ),\n", " \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n", @@ -1037,21 +898,17 @@ " ]\n", ")" ], - "execution_count": 17, + "execution_count": 24, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "logical-windsor", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" - }, - "originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d", + "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" @@ -1060,19 +917,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905105295, - "executionStopTime": 1730905105570, - "id": "viral-cheese", + "originalKey": "2eaa7f7b-872f-480a-b8fb-db55911bf266", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" - }, - "originalKey": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", - "outputsInitialized": true, - "requestMsgId": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", - "serverExecutionDuration": 4.1721769375727 + "executionStartTime": 1730901515256, + "executionStopTime": 1730901515518, + "serverExecutionDuration": 3.7259008968249, + "collapsed": false, + "requestMsgId": "2eaa7f7b-872f-480a-b8fb-db55911bf266" }, "source": [ "experiment = get_branin_experiment(minimize=True)\n", @@ -1080,13 +933,13 @@ "assert len(experiment.trials) == 0\n", "experiment.search_space" ], - "execution_count": 18, + "execution_count": 25, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:58:25] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + "[INFO 11-06 05:58:35] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" ] }, { @@ -1095,22 +948,18 @@ "text/plain": "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])" }, "metadata": {}, - "execution_count": 18 + "execution_count": 25 } ] }, { "cell_type": "markdown", "metadata": { - "id": "incident-newspaper", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" - }, - "originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc", + "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5a. Specifying `pending_observations`\n", @@ -1122,19 +971,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905107391, - "executionStopTime": 1730905109351, - "id": "casual-spread", + "originalKey": "8c1d5a77-ace0-4b93-ba5a-f019f13653ac", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" - }, - "originalKey": "36ca75ec-b37c-498c-a487-5652cd3dc34b", - "outputsInitialized": true, - "requestMsgId": "36ca75ec-b37c-498c-a487-5652cd3dc34b", - "serverExecutionDuration": 1696.8911510194 + "executionStartTime": 1730901517812, + "executionStopTime": 1730901519932, + "serverExecutionDuration": 1877.3847029079, + "collapsed": false, + "requestMsgId": "8c1d5a77-ace0-4b93-ba5a-f019f13653ac" }, "source": [ "for _ in range(10):\n", @@ -1155,7 +1000,7 @@ "\n", " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" ], - "execution_count": 19, + "execution_count": 26, "outputs": [ { "output_type": "stream", @@ -1171,6 +1016,13 @@ "Completed trial #5, suggested by BoTorch.\n" ] }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #6, suggested by BoTorch.\n" + ] + }, { "output_type": "stream", "name": "stdout", @@ -1197,15 +1049,11 @@ { "cell_type": "markdown", "metadata": { - "id": "circular-vermont", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" - }, - "originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25", + "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" @@ -1214,37 +1062,33 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905108832, - "executionStopTime": 1730905109399, - "id": "significant-particular", + "originalKey": "22d2dd61-03b4-4f40-af94-09b2cfcc7a7e", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" - }, - "originalKey": "68d567be-27d6-4244-b22f-d6c53ed2d303", - "outputsInitialized": true, - "requestMsgId": "68d567be-27d6-4244-b22f-d6c53ed2d303", - "serverExecutionDuration": 32.219067099504 + "executionStartTime": 1730901518750, + "executionStopTime": 1730901519982, + "serverExecutionDuration": 32.345620915294, + "collapsed": false, + "requestMsgId": "22d2dd61-03b4-4f40-af94-09b2cfcc7a7e" }, "source": [ "exp_to_df(experiment)" ], - "execution_count": 20, + "execution_count": 27, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[WARNING 11-06 06:58:29] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" + "[WARNING 11-06 05:58:39] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" ] }, { "output_type": "execute_result", "data": { - "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 79.581199 1.380743 12.280850\n1 1 1_0 COMPLETED ... 17.366840 6.989676 1.438049\n2 2 2_0 COMPLETED ... 61.299075 6.097525 7.568626\n3 3 3_0 COMPLETED ... 71.268812 -3.293570 4.231312\n4 4 4_0 COMPLETED ... 3.831238 -2.268755 10.230113\n5 5 5_0 COMPLETED ... 4.246417 -3.354258 10.886093\n6 6 6_0 COMPLETED ... 6.712767 9.467421 0.000000\n7 7 7_0 COMPLETED ... 17.508300 -5.000000 15.000000\n8 8 8_0 COMPLETED ... 2.507635 10.000000 2.251628\n9 9 9_0 COMPLETED ... 1.318731 8.983844 2.177537\n\n[10 rows x 7 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
trial_indexarm_nametrial_statusgeneration_methodbraninx1x2
000_0COMPLETEDSobol79.5811991.38074312.280850
111_0COMPLETEDSobol17.3668406.9896761.438049
222_0COMPLETEDSobol61.2990756.0975257.568626
333_0COMPLETEDSobol71.268812-3.2935704.231312
444_0COMPLETEDSobol3.831238-2.26875510.230113
555_0COMPLETEDBoTorch4.246417-3.35425810.886093
666_0COMPLETEDBoTorch6.7127679.4674210.000000
777_0COMPLETEDBoTorch17.508300-5.00000015.000000
888_0COMPLETEDBoTorch2.50763510.0000002.251628
999_0COMPLETEDBoTorch1.3187318.9838442.177537
\n
", + "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 45.658452 -1.388536 14.282915\n1 1 1_0 COMPLETED ... 20.676720 5.753865 2.664872\n2 2 2_0 COMPLETED ... 56.838245 8.953143 9.548486\n3 3 3_0 COMPLETED ... 22.800902 0.861398 7.283473\n4 4 4_0 COMPLETED ... 32.887329 0.426293 9.105935\n5 5 5_0 COMPLETED ... 16.584283 0.914660 3.800049\n6 6 6_0 COMPLETED ... 97.321924 -4.299831 5.683428\n7 7 7_0 COMPLETED ... 2.589939 2.975172 3.843420\n8 8 8_0 COMPLETED ... 16.611092 3.509416 5.951630\n9 9 9_0 COMPLETED ... 4.862684 3.524982 0.053942\n\n[10 rows x 7 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
trial_indexarm_nametrial_statusgeneration_methodbraninx1x2
000_0COMPLETEDSobol45.658452-1.38853614.282915
111_0COMPLETEDSobol20.6767205.7538652.664872
222_0COMPLETEDSobol56.8382458.9531439.548486
333_0COMPLETEDSobol22.8009020.8613987.283473
444_0COMPLETEDSobol32.8873290.4262939.105935
555_0COMPLETEDBoTorch16.5842830.9146603.800049
666_0COMPLETEDBoTorch97.321924-4.2998315.683428
777_0COMPLETEDBoTorch2.5899392.9751723.843420
888_0COMPLETEDBoTorch16.6110923.5094165.951630
999_0COMPLETEDBoTorch4.8626843.5249820.053942
\n
", "application/vnd.dataresource+json": { "schema": { "fields": [ @@ -1293,9 +1137,9 @@ "arm_name": "0_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 79.5811993025, - "x1": 1.3807432353, - "x2": 12.2808498144 + "branin": 45.6584524922, + "x1": -1.3885358721, + "x2": 14.2829149961 }, { "index": 1, @@ -1303,9 +1147,9 @@ "arm_name": "1_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 17.3668397479, - "x1": 6.9896756578, - "x2": 1.4380489429 + "branin": 20.6767197174, + "x1": 5.7538650231, + "x2": 2.6648724033 }, { "index": 2, @@ -1313,9 +1157,9 @@ "arm_name": "2_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 61.2990749448, - "x1": 6.097525456, - "x2": 7.5686264969 + "branin": 56.8382447759, + "x1": 8.9531433629, + "x2": 9.5484864246 }, { "index": 3, @@ -1323,9 +1167,9 @@ "arm_name": "3_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 71.268812081, - "x1": -3.2935703546, - "x2": 4.231311623 + "branin": 22.8009019164, + "x1": 0.8613982378, + "x2": 7.2834726842 }, { "index": 4, @@ -1333,9 +1177,9 @@ "arm_name": "4_0", "trial_status": "COMPLETED", "generation_method": "Sobol", - "branin": 3.8312383283, - "x1": -2.2687551333, - "x2": 10.2301133936 + "branin": 32.8873287104, + "x1": 0.4262934485, + "x2": 9.1059345799 }, { "index": 5, @@ -1343,9 +1187,9 @@ "arm_name": "5_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 4.2464169491, - "x1": -3.3542581623, - "x2": 10.8860926765 + "branin": 16.5842834835, + "x1": 0.9146598153, + "x2": 3.8000487056 }, { "index": 6, @@ -1353,9 +1197,9 @@ "arm_name": "6_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 6.7127667696, - "x1": 9.4674207228, - "x2": 0 + "branin": 97.3219235984, + "x1": -4.2998308456, + "x2": 5.6834276208 }, { "index": 7, @@ -1363,9 +1207,9 @@ "arm_name": "7_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 17.5082995158, - "x1": -5, - "x2": 15 + "branin": 2.5899390643, + "x1": 2.9751721652, + "x2": 3.8434196708 }, { "index": 8, @@ -1373,9 +1217,9 @@ "arm_name": "8_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 2.5076350936, - "x1": 10, - "x2": 2.2516281613 + "branin": 16.6110923305, + "x1": 3.5094163182, + "x2": 5.9516303038 }, { "index": 9, @@ -1383,30 +1227,26 @@ "arm_name": "9_0", "trial_status": "COMPLETED", "generation_method": "BoTorch", - "branin": 1.3187305399, - "x1": 8.983844442, - "x2": 2.1775366828 + "branin": 4.8626836874, + "x1": 3.5249824947, + "x2": 0.0539416175 } ] } }, "metadata": {}, - "execution_count": 20 + "execution_count": 27 } ] }, { "cell_type": "markdown", "metadata": { - "id": "obvious-transparency", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" - }, - "originalKey": "633c66af-a89f-4f03-a88b-866767d0a52f", + "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 6. Customizing a `Surrogate` or `Acquisition`\n", @@ -1419,19 +1259,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905111356, - "executionStopTime": 1730905111601, - "id": "organizational-balance", + "originalKey": "df4a44fd-1d24-45f5-9918-c8590c32cbd7", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" - }, - "originalKey": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", - "outputsInitialized": true, - "requestMsgId": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", - "serverExecutionDuration": 2.4916339898482 + "executionStartTime": 1730901522195, + "executionStopTime": 1730901522454, + "serverExecutionDuration": 2.353421994485, + "collapsed": false, + "requestMsgId": "df4a44fd-1d24-45f5-9918-c8590c32cbd7" }, "source": [ "from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n", @@ -1451,21 +1287,17 @@ " ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n", " ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default" ], - "execution_count": 21, + "execution_count": 28, "outputs": [] }, { "cell_type": "markdown", "metadata": { - "id": "theoretical-horizon", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" - }, - "originalKey": "0ec8606d-9d5b-4bcb-ad7e-f54839ad6f9b", + "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9", + "showInput": false, "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):" @@ -1474,19 +1306,15 @@ { "cell_type": "code", "metadata": { - "collapsed": false, - "executionStartTime": 1730905113131, - "executionStopTime": 1730905113387, - "id": "approximate-rolling", + "originalKey": "3ad55d13-20eb-49a5-86bf-f521c76fb1bb", + "outputsInitialized": true, "isAgentGenerated": false, "language": "python", - "metadata": { - "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" - }, - "originalKey": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", - "outputsInitialized": true, - "requestMsgId": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", - "serverExecutionDuration": 12.22031598445 + "executionStartTime": 1730901522848, + "executionStopTime": 1730901523109, + "serverExecutionDuration": 12.289707083255, + "collapsed": false, + "requestMsgId": "3ad55d13-20eb-49a5-86bf-f521c76fb1bb" }, "source": [ "Models.BOTORCH_MODULAR(\n", @@ -1496,13 +1324,13 @@ " botorch_acqf_class=MyAcquisitionFunctionClass,\n", ")" ], - "execution_count": 22, + "execution_count": 29, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "[INFO 11-06 06:58:33] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + "[INFO 11-06 05:58:43] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" ] }, { @@ -1511,22 +1339,17 @@ "text/plain": "TorchModelBridge(model=BoTorchModel)" }, "metadata": {}, - "execution_count": 22 + "execution_count": 29 } ] }, { "cell_type": "markdown", "metadata": { - "id": "representative-implement", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" - }, - "originalKey": "cdcfb2bc-3016-4681-9fff-407f28321c3f", + "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n", @@ -1542,15 +1365,10 @@ { "cell_type": "markdown", "metadata": { - "id": "framed-intermediate", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" - }, - "originalKey": "ff03d674-f584-403f-ba65-f1bab921845b", + "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "------" @@ -1559,15 +1377,10 @@ { "cell_type": "markdown", "metadata": { - "id": "metropolitan-feedback", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" - }, - "originalKey": "f71fcfa1-fc59-4bfb-84d6-b94ea5298bfa", + "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 1: Methods available on `BoTorchModel`\n", @@ -1590,15 +1403,10 @@ { "cell_type": "markdown", "metadata": { - "id": "possible-transsexual", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" - }, - "originalKey": "91cedde4-8911-441f-af05-eb124581cbbc", + "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 2: Default surrogate models and acquisition functions\n", @@ -1617,15 +1425,10 @@ { "cell_type": "markdown", "metadata": { - "id": "continuous-strain", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" - }, - "originalKey": "c8b0f933-8df6-479b-aa61-db75ca877624", + "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n", @@ -1636,15 +1439,10 @@ { "cell_type": "markdown", "metadata": { - "id": "broadband-voice", - "isAgentGenerated": false, - "language": "markdown", - "metadata": { - "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" - }, - "originalKey": "4d82f49a-3a8b-42f0-a4f5-5c079b793344", + "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc", "outputsInitialized": false, - "showInput": false + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "The two options for handling this error are:\n",