diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py
index 74518e3ebe7..d4fbeda78a2 100644
--- a/ax/benchmark/methods/modular_botorch.py
+++ b/ax/benchmark/methods/modular_botorch.py
@@ -14,7 +14,7 @@
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
-from ax.models.torch.botorch_modular.model import SurrogateSpec
+from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.service.scheduler import SchedulerOptions
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import LogExpectedImprovement
diff --git a/ax/benchmark/tests/methods/test_methods.py b/ax/benchmark/tests/methods/test_methods.py
index 9c63e2705e0..0dc1aa60cee 100644
--- a/ax/benchmark/tests/methods/test_methods.py
+++ b/ax/benchmark/tests/methods/test_methods.py
@@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None:
self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = model_kwargs["surrogate_spec"]
self.assertEqual(
- surrogate_spec.botorch_model_class.__name__,
+ surrogate_spec.model_configs[0].botorch_model_class.__name__,
"SingleTaskGP",
)
diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py
index 9d903ad8c9e..27f0ea653a8 100644
--- a/ax/modelbridge/registry.py
+++ b/ax/modelbridge/registry.py
@@ -59,10 +59,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.random.uniform import UniformGenerator
from ax.models.torch.botorch import BotorchModel
-from ax.models.torch.botorch_modular.model import (
- BoTorchModel as ModularBoTorchModel,
- SurrogateSpec,
-)
+from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel
+from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.models.torch.cbo_sac import SACBO
from ax.utils.common.kwargs import (
diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py
index 7f77bbde548..20f17d7ce70 100644
--- a/ax/modelbridge/tests/test_registry.py
+++ b/ax/modelbridge/tests/test_registry.py
@@ -28,8 +28,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
-from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.model import BoTorchModel
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
@@ -100,7 +100,7 @@ def test_SAASBO(self) -> None:
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
- saasbo.model.surrogate.model_configs[0].botorch_model_class,
+ saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class,
SaasFullyBayesianSingleTaskGP,
)
diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py
index 03804f6f7de..ebd04a2f2e3 100644
--- a/ax/models/torch/botorch_modular/model.py
+++ b/ax/models/torch/botorch_modular/model.py
@@ -10,7 +10,6 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping, Sequence
-from dataclasses import dataclass, field
from typing import Any
import numpy.typing as npt
@@ -23,12 +22,11 @@
get_rounding_func,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
- ModelConfig,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
@@ -38,53 +36,10 @@
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
-from botorch.models.model import Model
-from botorch.models.transforms.input import InputTransform
-from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
-from gpytorch.kernels.kernel import Kernel
-from gpytorch.likelihoods import Likelihood
-from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
-from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor
-@dataclass(frozen=True)
-class SurrogateSpec:
- """
- Fields in the SurrogateSpec dataclass correspond to arguments in
- ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
- outcomes the Surrogate is responsible for modeling.
- When ``BotorchModel.fit`` is called, these fields will be used to construct the
- requisite Surrogate objects.
- If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.
- """
-
- botorch_model_class: type[Model] | None = None
- botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)
-
- mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
- mll_kwargs: dict[str, Any] = field(default_factory=dict)
-
- covar_module_class: type[Kernel] | None = None
- covar_module_kwargs: dict[str, Any] | None = None
-
- likelihood_class: type[Likelihood] | None = None
- likelihood_kwargs: dict[str, Any] | None = None
-
- input_transform_classes: list[type[InputTransform]] | None = None
- input_transform_options: dict[str, dict[str, Any]] | None = None
-
- outcome_transform_classes: list[type[OutcomeTransform]] | None = None
- outcome_transform_options: dict[str, dict[str, Any]] | None = None
-
- allow_batched_models: bool = True
-
- model_configs: list[ModelConfig] = field(default_factory=list)
- metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
- outcomes: list[str] = field(default_factory=list)
-
-
class BoTorchModel(TorchModel, Base):
"""**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
@@ -230,31 +185,17 @@ def fit(
# If a surrogate has not been constructed, construct it.
if self._surrogate is None:
- if (spec := self.surrogate_spec) is not None:
- self._surrogate = Surrogate(
- botorch_model_class=spec.botorch_model_class,
- model_options=spec.botorch_model_kwargs,
- mll_class=spec.mll_class,
- mll_options=spec.mll_kwargs,
- covar_module_class=spec.covar_module_class,
- covar_module_options=spec.covar_module_kwargs,
- likelihood_class=spec.likelihood_class,
- likelihood_options=spec.likelihood_kwargs,
- input_transform_classes=spec.input_transform_classes,
- input_transform_options=spec.input_transform_options,
- outcome_transform_classes=spec.outcome_transform_classes,
- outcome_transform_options=spec.outcome_transform_options,
- model_configs=spec.model_configs,
- metric_to_model_configs=spec.metric_to_model_configs,
- allow_batched_models=spec.allow_batched_models,
- )
+ if self.surrogate_spec is not None:
+ self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec)
else:
self._surrogate = Surrogate()
# Fit the surrogate.
- for config in self.surrogate.model_configs:
+ for config in self.surrogate.surrogate_spec.model_configs:
config.model_options.update(additional_model_inputs)
- for config_list in self.surrogate.metric_to_model_configs.values():
+ for (
+ config_list
+ ) in self.surrogate.surrogate_spec.metric_to_model_configs.values():
for config in config_list:
config.model_options.update(additional_model_inputs)
self.surrogate.fit(
diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py
index 80ecbd2a484..802ed179b70 100644
--- a/ax/models/torch/botorch_modular/surrogate.py
+++ b/ax/models/torch/botorch_modular/surrogate.py
@@ -13,6 +13,7 @@
from collections import OrderedDict
from collections.abc import Sequence
from copy import deepcopy
+from dataclasses import dataclass, field
from logging import Logger
from typing import Any
@@ -271,17 +272,169 @@ def _set_formatted_inputs(
formatted_model_inputs[input_name] = input_class(**input_options)
-def _raise_deprecation_warning(*args: Any, **kwargs: Any) -> None:
+def _raise_deprecation_warning(
+ is_surrogate: bool = False,
+ **kwargs: Any,
+) -> bool:
+ """Raise deprecation warnings for deprecated arguments.
+
+ Args:
+ is_surrogate: A boolean indicating whether the warning is called from
+ Surrogate.
+
+ Returns:
+ A boolean indicating whether any deprecation warnings were raised.
+ """
+ msg = "{k} is deprecated and will be removed in a future version. "
+ if is_surrogate:
+ msg += "Please specify {k} via `surrogate_spec.model_configs`."
+ else:
+ msg += "Please specify {k} via `model_configs`."
+ warnings_raised = False
+ default_is_dict = {"botorch_model_kwargs", "mll_kwargs"}
for k, v in kwargs.items():
- if (v is not None and k != "mll_class") or (
+ should_raise = False
+ if k in default_is_dict:
+ if v != {}:
+ should_raise = True
+ elif (v is not None and k != "mll_class") or (
k == "mll_class" and v is not ExactMarginalLogLikelihood
):
+ should_raise = True
+ if should_raise:
warnings.warn(
- f"{k} is deprecated and will be removed in a future version. "
- f"Please specify {k} via `model_configs`.",
+ msg.format(k=k),
DeprecationWarning,
stacklevel=3,
)
+ warnings_raised = True
+ return warnings_raised
+
+
+def get_model_config_from_deprecated_args(
+ botorch_model_class: type[Model] | None,
+ model_options: dict[str, Any] | None,
+ mll_class: type[MarginalLogLikelihood] | None,
+ mll_options: dict[str, Any] | None,
+ outcome_transform_classes: list[type[OutcomeTransform]] | None,
+ outcome_transform_options: dict[str, dict[str, Any]] | None,
+ input_transform_classes: list[type[InputTransform]] | None,
+ input_transform_options: dict[str, dict[str, Any]] | None,
+ covar_module_class: type[Kernel] | None,
+ covar_module_options: dict[str, Any] | None,
+ likelihood_class: type[Likelihood] | None,
+ likelihood_options: dict[str, Any] | None,
+) -> ModelConfig:
+ """Construct a ModelConfig from deprecated arguments."""
+ model_config_kwargs = {
+ "botorch_model_class": botorch_model_class,
+ "model_options": (model_options or {}).copy(),
+ "mll_class": mll_class,
+ "mll_options": (mll_options or {}).copy(),
+ "outcome_transform_classes": outcome_transform_classes,
+ "outcome_transform_options": outcome_transform_options,
+ "input_transform_classes": input_transform_classes,
+ "input_transform_options": input_transform_options,
+ "covar_module_class": covar_module_class,
+ "covar_module_options": covar_module_options,
+ "likelihood_class": likelihood_class,
+ "likelihood_options": likelihood_options,
+ }
+ model_config_kwargs = {
+ k: v for k, v in model_config_kwargs.items() if v is not None
+ }
+ # pyre-fixme [6]: Incompatible parameter type [6]: In call
+ # `ModelConfig.__init__`, for 1st positional argument, expected
+ # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any],
+ # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]],
+ # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood,
+ # Model]], Type[Likelihood], Type[Kernel]]`.
+ return ModelConfig(**model_config_kwargs)
+
+
+@dataclass(frozen=True)
+class SurrogateSpec:
+ """
+ Fields in the SurrogateSpec dataclass correspond to arguments in
+ ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
+ outcomes the Surrogate is responsible for modeling.
+ When ``BotorchModel.fit`` is called, these fields will be used to construct the
+ requisite Surrogate objects.
+ If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.
+ """
+
+ botorch_model_class: type[Model] | None = None
+ botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)
+
+ mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
+ mll_kwargs: dict[str, Any] = field(default_factory=dict)
+
+ covar_module_class: type[Kernel] | None = None
+ covar_module_kwargs: dict[str, Any] | None = None
+
+ likelihood_class: type[Likelihood] | None = None
+ likelihood_kwargs: dict[str, Any] | None = None
+
+ input_transform_classes: list[type[InputTransform]] | None = None
+ input_transform_options: dict[str, dict[str, Any]] | None = None
+
+ outcome_transform_classes: list[type[OutcomeTransform]] | None = None
+ outcome_transform_options: dict[str, dict[str, Any]] | None = None
+
+ allow_batched_models: bool = True
+
+ model_configs: list[ModelConfig] = field(default_factory=list)
+ metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
+ outcomes: list[str] = field(default_factory=list)
+
+ def __post_init__(self) -> None:
+ warnings_raised = _raise_deprecation_warning(
+ is_surrogate=False,
+ botorch_model_class=self.botorch_model_class,
+ botorch_model_kwargs=self.botorch_model_kwargs,
+ mll_class=self.mll_class,
+ mll_kwargs=self.mll_kwargs,
+ outcome_transform_classes=self.outcome_transform_classes,
+ outcome_transform_options=self.outcome_transform_options,
+ input_transform_classes=self.input_transform_classes,
+ input_transform_options=self.input_transform_options,
+ covar_module_class=self.covar_module_class,
+ covar_module_options=self.covar_module_kwargs,
+ likelihood_class=self.likelihood_class,
+ likelihood_options=self.likelihood_kwargs,
+ )
+ if len(self.model_configs) == 0:
+ model_config = get_model_config_from_deprecated_args(
+ botorch_model_class=self.botorch_model_class,
+ model_options=self.botorch_model_kwargs,
+ mll_class=self.mll_class,
+ mll_options=self.mll_kwargs,
+ outcome_transform_classes=self.outcome_transform_classes,
+ outcome_transform_options=self.outcome_transform_options,
+ input_transform_classes=self.input_transform_classes,
+ input_transform_options=self.input_transform_options,
+ covar_module_class=self.covar_module_class,
+ covar_module_options=self.covar_module_kwargs,
+ likelihood_class=self.likelihood_class,
+ likelihood_options=self.likelihood_kwargs,
+ )
+ # re-initialize with the non-deprecated arguments
+ self.__init__(
+ model_configs=[model_config],
+ metric_to_model_configs=self.metric_to_model_configs,
+ allow_batched_models=self.allow_batched_models,
+ outcomes=self.outcomes,
+ )
+ elif warnings_raised:
+ raise UserInputError(
+ "model_configs and deprecated arguments were both specified. "
+ "Please use model_configs and remove deprecated arguments."
+ )
+ if len(self.model_configs) > 1 or any(
+ len(model_config) > 1
+ for model_config in self.metric_to_model_configs.values()
+ ):
+ raise NotImplementedError("Only one model config per metric is supported.")
class Surrogate(Base):
@@ -359,23 +512,23 @@ class string names and the values are dictionaries of input transform
def __init__(
self,
+ surrogate_spec: SurrogateSpec | None = None,
botorch_model_class: type[Model] | None = None,
model_options: dict[str, Any] | None = None,
mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
mll_options: dict[str, Any] | None = None,
- outcome_transform_classes: Sequence[type[OutcomeTransform]] | None = None,
+ outcome_transform_classes: list[type[OutcomeTransform]] | None = None,
outcome_transform_options: dict[str, dict[str, Any]] | None = None,
- input_transform_classes: Sequence[type[InputTransform]] | None = None,
+ input_transform_classes: list[type[InputTransform]] | None = None,
input_transform_options: dict[str, dict[str, Any]] | None = None,
covar_module_class: type[Kernel] | None = None,
covar_module_options: dict[str, Any] | None = None,
likelihood_class: type[Likelihood] | None = None,
likelihood_options: dict[str, Any] | None = None,
- model_configs: list[ModelConfig] | None = None,
- metric_to_model_configs: dict[str, list[ModelConfig]] | None = None,
allow_batched_models: bool = True,
) -> None:
- _raise_deprecation_warning(
+ warnings_raised = _raise_deprecation_warning(
+ is_surrogate=True,
botorch_model_class=botorch_model_class,
model_options=model_options,
mll_class=mll_class,
@@ -389,43 +542,34 @@ def __init__(
likelihood_class=likelihood_class,
likelihood_options=likelihood_options,
)
- model_configs = model_configs or []
- if len(model_configs) == 0:
- model_config_kwargs = {
- "botorch_model_class": botorch_model_class,
- "model_options": (model_options or {}).copy(),
- "mll_class": mll_class,
- "mll_options": mll_options,
- "outcome_transform_classes": outcome_transform_classes,
- "outcome_transform_options": outcome_transform_options,
- "input_transform_classes": input_transform_classes,
- "input_transform_options": input_transform_options,
- "covar_module_class": covar_module_class,
- "covar_module_options": covar_module_options,
- "likelihood_class": likelihood_class,
- "likelihood_options": likelihood_options,
- }
- model_config_kwargs = {
- k: v for k, v in model_config_kwargs.items() if v is not None
- }
- # pyre-fixme [6]: Incompatible parameter type [6]: In call
- # `ModelConfig.__init__`, for 1st positional argument, expected
- # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any],
- # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]],
- # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood,
- # Model]], Type[Likelihood], Type[Kernel]]`.
- model_configs = [ModelConfig(**model_config_kwargs)]
- self.model_configs: list[ModelConfig] = model_configs
- self.metric_to_model_configs: dict[str, list[ModelConfig]] = (
- metric_to_model_configs or {}
- )
- if len(self.model_configs) > 1 or any(
- len(model_config) > 1
- for model_config in self.metric_to_model_configs.values()
- ):
- raise NotImplementedError("Only one model config per metric is supported.")
+ # check if surrogate_spec is provided
+ if surrogate_spec is None:
+ # create surrogate spec from deprecated arguments
+ model_config = get_model_config_from_deprecated_args(
+ botorch_model_class=botorch_model_class,
+ model_options=model_options,
+ mll_class=mll_class,
+ mll_options=mll_options,
+ outcome_transform_classes=outcome_transform_classes,
+ outcome_transform_options=outcome_transform_options,
+ input_transform_classes=input_transform_classes,
+ input_transform_options=input_transform_options,
+ covar_module_class=covar_module_class,
+ covar_module_options=covar_module_options,
+ likelihood_class=likelihood_class,
+ likelihood_options=likelihood_options,
+ )
+ surrogate_spec = SurrogateSpec(
+ model_configs=[model_config], allow_batched_models=allow_batched_models
+ )
+
+ elif warnings_raised:
+ raise UserInputError(
+ "model_configs and deprecated arguments were both specified. "
+ "Please use model_configs and remove deprecated arguments."
+ )
- self.allow_batched_models = allow_batched_models
+ self.surrogate_spec: SurrogateSpec = surrogate_spec
# Store the last dataset used to fit the model for a given metric(s).
# If the new dataset is identical, we will skip model fitting for that metric.
# The keys are `tuple(dataset.outcome_names)`.
@@ -445,11 +589,7 @@ def __init__(
self._model: Model | None = None
def __repr__(self) -> str:
- return (
- f"<{self.__class__.__name__}"
- f" model_configs={self.model_configs},"
- f" metric_to_model_configs={self.metric_to_model_configs}>"
- )
+ return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>"
@property
def model(self) -> Model:
@@ -614,20 +754,22 @@ def fit(
# the batched multi-output case, so we first see which model would be chosen
# given the Yvars and the properties of data.
if (
- len(self.model_configs) == 1
- and self.model_configs[0].botorch_model_class is None
+ len(self.surrogate_spec.model_configs) == 1
+ and self.surrogate_spec.model_configs[0].botorch_model_class is None
):
default_botorch_model_class = choose_model_class(
datasets=datasets, search_space_digest=search_space_digest
)
else:
- default_botorch_model_class = self.model_configs[0].botorch_model_class
+ default_botorch_model_class = self.surrogate_spec.model_configs[
+ 0
+ ].botorch_model_class
should_use_model_list = use_model_list(
datasets=datasets,
botorch_model_class=not_none(default_botorch_model_class),
- model_configs=self.model_configs,
- allow_batched_models=self.allow_batched_models,
- metric_to_model_configs=self.metric_to_model_configs,
+ model_configs=self.surrogate_spec.model_configs,
+ allow_batched_models=self.surrogate_spec.allow_batched_models,
+ metric_to_model_configs=self.surrogate_spec.metric_to_model_configs,
)
if not should_use_model_list and len(datasets) > 1:
@@ -646,7 +788,7 @@ def fit(
else:
submodel_state_dict = state_dict
model_config = None
- if len(self.metric_to_model_configs) > 0:
+ if len(self.surrogate_spec.metric_to_model_configs) > 0:
# if metric_to_model_configs is not empty, then
# we are using a model list and each dataset
# should have only one outcome.
@@ -655,7 +797,7 @@ def fit(
"Each dataset should have only one outcome when "
"metric_to_model_configs is specified."
)
- model_config_list = self.metric_to_model_configs.get(
+ model_config_list = self.surrogate_spec.metric_to_model_configs.get(
dataset.outcome_names[0]
)
@@ -663,7 +805,7 @@ def fit(
if model_config_list is not None:
model_config = model_config_list[0]
if model_config is None:
- model_config = self.model_configs[0]
+ model_config = self.surrogate_spec.model_configs[0]
model = self._construct_model(
dataset=dataset,
search_space_digest=search_space_digest,
@@ -828,10 +970,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
"""Serialize attributes of this surrogate, to be passed back to it
as kwargs on reinstantiation.
"""
- return {
- "model_configs": self.model_configs,
- "metric_to_model_configs": self.metric_to_model_configs,
- }
+ return {"surrogate_spec": self.surrogate_spec}
def _extract_construct_input_transform_args(
self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest
diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py
index 389b3072175..797719a0f5c 100644
--- a/ax/models/torch/tests/test_model.py
+++ b/ax/models/torch/tests/test_model.py
@@ -21,9 +21,8 @@
from ax.models.torch.botorch_modular.model import (
BoTorchModel,
choose_botorch_acqf_class,
- SurrogateSpec,
)
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
construct_acquisition_and_optimizer_options,
@@ -719,32 +718,17 @@ def test_evaluate_acquisition_function(
def test_surrogate_model_options_propagation(
self, _m1: Mock, _m2: Mock, mock_init: Mock
) -> None:
- model = BoTorchModel(
- surrogate_spec=SurrogateSpec(
- botorch_model_kwargs={"some_option": "some_value"}
- )
+ surrogate_spec = SurrogateSpec(
+ botorch_model_kwargs={"some_option": "some_value"}
)
+ model = BoTorchModel(surrogate_spec=surrogate_spec)
model.fit(
datasets=self.non_block_design_training_data,
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(
- botorch_model_class=None,
- model_options={"some_option": "some_value"},
- mll_class=ExactMarginalLogLikelihood,
- mll_options={},
- covar_module_class=None,
- covar_module_options=None,
- likelihood_class=None,
- likelihood_options=None,
- input_transform_classes=None,
- input_transform_options=None,
- outcome_transform_classes=None,
- outcome_transform_options=None,
- model_configs=[],
- metric_to_model_configs={},
- allow_batched_models=True,
+ surrogate_spec=surrogate_spec,
)
@mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate)
@@ -753,28 +737,15 @@ def test_surrogate_model_options_propagation(
def test_surrogate_options_propagation(
self, _m1: Mock, _m2: Mock, mock_init: Mock
) -> None:
- model = BoTorchModel(surrogate_spec=SurrogateSpec(allow_batched_models=False))
+ surrogate_spec = SurrogateSpec(allow_batched_models=False)
+ model = BoTorchModel(surrogate_spec=surrogate_spec)
model.fit(
datasets=self.non_block_design_training_data,
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(
- botorch_model_class=None,
- model_options={},
- mll_class=ExactMarginalLogLikelihood,
- mll_options={},
- covar_module_class=None,
- covar_module_options=None,
- likelihood_class=None,
- likelihood_options=None,
- input_transform_classes=None,
- input_transform_options=None,
- outcome_transform_classes=None,
- outcome_transform_options=None,
- model_configs=[],
- metric_to_model_configs={},
- allow_batched_models=False,
+ surrogate_spec=surrogate_spec,
)
@mock.patch(
diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py
index 6bbd600df90..eadce0951bf 100644
--- a/ax/models/torch/tests/test_surrogate.py
+++ b/ax/models/torch/tests/test_surrogate.py
@@ -19,7 +19,11 @@
from ax.exceptions.core import UserInputError
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
-from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate
+from ax.models.torch.botorch_modular.surrogate import (
+ _extract_model_kwargs,
+ Surrogate,
+ SurrogateSpec,
+)
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
fit_botorch_model,
@@ -39,7 +43,7 @@
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import InputPerturbation, Normalize
-from botorch.models.transforms.outcome import Standardize
+from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import SupervisedDataset
from gpytorch.constraints import GreaterThan, Interval
from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel
@@ -202,7 +206,7 @@ def _get_surrogate(
mll_options = None
if use_outcome_transform:
- outcome_transform_classes = [Standardize]
+ outcome_transform_classes: list[type[OutcomeTransform]] = [Standardize]
outcome_transform_options = {"Standardize": {"m": 1}}
else:
outcome_transform_classes = None
@@ -222,10 +226,15 @@ def test_init(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
self.assertEqual(
- surrogate.model_configs[0].botorch_model_class, botorch_model_class
+ surrogate.surrogate_spec.model_configs[0].botorch_model_class,
+ botorch_model_class,
+ )
+ self.assertEqual(
+ surrogate.surrogate_spec.model_configs[0].mll_class, self.mll_class
)
- self.assertEqual(surrogate.model_configs[0].mll_class, self.mll_class)
- self.assertTrue(surrogate.allow_batched_models) # True by default
+ self.assertTrue(
+ surrogate.surrogate_spec.allow_batched_models
+ ) # True by default
def test_clone_reset(self) -> None:
surrogate = self._get_surrogate(botorch_model_class=SingleTaskGP)[0]
@@ -454,7 +463,7 @@ def test_construct_model(self) -> None:
model = surrogate._construct_model(
dataset=self.training_data[0],
search_space_digest=self.search_space_digest,
- model_config=surrogate.model_configs[0],
+ model_config=surrogate.surrogate_spec.model_configs[0],
default_botorch_model_class=botorch_model_class,
state_dict=None,
refit=True,
@@ -480,7 +489,7 @@ def test_construct_model(self) -> None:
new_model = surrogate._construct_model(
dataset=self.training_data[0],
search_space_digest=self.search_space_digest,
- model_config=surrogate.model_configs[0],
+ model_config=surrogate.surrogate_spec.model_configs[0],
default_botorch_model_class=botorch_model_class,
state_dict=None,
refit=True,
@@ -517,7 +526,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None:
"likelihood_class": FixedNoiseGaussianLikelihood,
}
if use_model_config:
- surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)])
+ surrogate = Surrogate(
+ surrogate_spec=SurrogateSpec(
+ model_configs=[ModelConfig(**model_config_kwargs)]
+ )
+ )
else:
surrogate = Surrogate(**model_config_kwargs)
with self.assertRaisesRegex(UserInputError, "does not support"):
@@ -536,7 +549,11 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None:
"likelihood_options": {"noise_constraint": noise_constraint},
}
if use_model_config:
- surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)])
+ surrogate = Surrogate(
+ surrogate_spec=SurrogateSpec(
+ model_configs=[ModelConfig(**model_config_kwargs)]
+ )
+ )
else:
surrogate = Surrogate(**model_config_kwargs)
surrogate.fit(
@@ -552,7 +569,8 @@ def test_construct_custom_model(self, use_model_config: bool = False) -> None:
noise_constraint.__dict__,
)
self.assertEqual(
- surrogate.model_configs[0].mll_class, LeaveOneOutPseudoLikelihood
+ surrogate.surrogate_spec.model_configs[0].mll_class,
+ LeaveOneOutPseudoLikelihood,
)
self.assertEqual(type(model.covar_module), RBFKernel)
self.assertEqual(model.covar_module.ard_num_dims, 3)
@@ -562,11 +580,13 @@ def test_construct_custom_model_with_config(self) -> None:
def test_construct_model_with_metric_to_model_configs(self) -> None:
surrogate = Surrogate(
- metric_to_model_configs={
- "metric": [ModelConfig()],
- "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)],
- },
- model_configs=[ModelConfig(covar_module_class=LinearKernel)],
+ surrogate_spec=SurrogateSpec(
+ metric_to_model_configs={
+ "metric": [ModelConfig()],
+ "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)],
+ },
+ model_configs=[ModelConfig(covar_module_class=LinearKernel)],
+ )
)
training_data = self.training_data + [
SupervisedDataset(
@@ -600,13 +620,13 @@ def test_multiple_model_configs_error(self) -> None:
with self.assertRaisesRegex(
NotImplementedError, "Only one model config per metric is supported."
):
- Surrogate(
+ SurrogateSpec(
model_configs=[ModelConfig(), ModelConfig()],
)
with self.assertRaisesRegex(
NotImplementedError, "Only one model config per metric is supported."
):
- Surrogate(
+ SurrogateSpec(
metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]},
)
@@ -720,10 +740,7 @@ def test_best_out_of_sample_point(self) -> None:
def test_serialize_attributes_as_kwargs(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
- expected = {
- "model_configs": surrogate.model_configs,
- "metric_to_model_configs": surrogate.metric_to_model_configs,
- }
+ expected = {"surrogate_spec": surrogate.surrogate_spec}
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)
@mock_botorch_optimize
@@ -752,7 +769,7 @@ def test_w_robust_digest(self) -> None:
environmental_variables=[],
multiplicative=False,
)
- surrogate.model_configs[0].input_transform_classes = [Normalize]
+ surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize]
with self.assertRaisesRegex(NotImplementedError, "input transforms"):
surrogate.fit(
datasets=self.training_data,
@@ -905,7 +922,7 @@ def setUp(self) -> None:
)
def test_init(self) -> None:
- model_config = self.surrogate.model_configs[0]
+ model_config = self.surrogate.surrogate_spec.model_configs[0]
self.assertEqual(
[model_config.botorch_model_class] * 2,
[*self.botorch_submodel_class_per_outcome.values()],
@@ -925,7 +942,9 @@ def test_init(self) -> None:
def test_construct_per_outcome_options(
self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock
) -> None:
- self.surrogate.model_configs[0].model_options.update({"output_tasks": [2]})
+ self.surrogate.surrogate_spec.model_configs[0].model_options.update(
+ {"output_tasks": [2]}
+ )
for fixed_noise in (False, True):
mock_fit.reset_mock()
mock_MTGP_construct_inputs.reset_mock()
@@ -1021,7 +1040,7 @@ def test_fit(
# pyre-ignore[6]: Incompatible parameter type: In call
# `issubclass`, for 1st positional argument, expected
# `Type[typing.Any]` but got `Optional[Type[Model]]`.
- surrogate.model_configs[0].botorch_model_class,
+ surrogate.surrogate_spec.model_configs[0].botorch_model_class,
MultiTaskGP,
)
search_space_digest = (
@@ -1206,7 +1225,8 @@ def test_construct_custom_model(self) -> None:
models = checked_cast(ModelListGP, surrogate._model).models
self.assertEqual(len(models), 2)
self.assertEqual(
- surrogate.model_configs[0].mll_class, ExactMarginalLogLikelihood
+ surrogate.surrogate_spec.model_configs[0].mll_class,
+ ExactMarginalLogLikelihood,
)
# Make sure we properly copied the transforms.
self.assertNotEqual(
@@ -1260,7 +1280,7 @@ def test_w_robust_digest(self) -> None:
environmental_variables=[],
multiplicative=False,
)
- surrogate.model_configs[0].input_transform_classes = [Normalize]
+ surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize]
with self.assertRaisesRegex(NotImplementedError, "input transforms"):
surrogate.fit(
datasets=self.supervised_training_data,
diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py
index 5b31efb70f8..140ad1b33bd 100644
--- a/ax/storage/json_store/decoder.py
+++ b/ax/storage/json_store/decoder.py
@@ -44,8 +44,7 @@
TransitionCriterion,
TrialBasedCriterion,
)
-from ax.models.torch.botorch_modular.model import SurrogateSpec
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import ModelConfig
from ax.storage.json_store.decoders import (
batch_trial_from_json,
diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py
index 02498d03669..3f0aad31e66 100644
--- a/ax/storage/json_store/registry.py
+++ b/ax/storage/json_store/registry.py
@@ -96,8 +96,8 @@
TransitionCriterion,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
-from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.model import BoTorchModel
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import ModelConfig
from ax.models.winsorization_config import WinsorizationConfig
from ax.runners.synthetic import SyntheticRunner
diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py
index 24c8a634e57..a2e17f95acb 100644
--- a/ax/storage/json_store/tests/test_json_store.py
+++ b/ax/storage/json_store/tests/test_json_store.py
@@ -600,11 +600,11 @@ def test_encode_decode_surrogate_spec(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
- org_as_dict = dataclasses.asdict(org_object)
- converted_as_dict = dataclasses.asdict(converted_object)
+ org_as_dict = dataclasses.asdict(org_object)["model_configs"][0]
+ converted_as_dict = dataclasses.asdict(converted_object)["model_configs"][0]
# Covar module kwargs will fail comparison. Manually compare.
- org_covar_kwargs = org_as_dict.pop("covar_module_kwargs")
- converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs")
+ org_covar_kwargs = org_as_dict.pop("covar_module_options")
+ converted_covar_kwargs = converted_as_dict.pop("covar_module_options")
self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys())
for k in org_covar_kwargs:
org_ = org_covar_kwargs[k]
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index 6aaf96ab9de..8e130319fe2 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -28,7 +28,7 @@
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.models.torch.botorch_modular.model import BoTorchModel
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.service.scheduler import SchedulerOptions
from ax.utils.common.constants import Keys
from ax.utils.testing.core_stubs import (
@@ -189,7 +189,11 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod:
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
- "surrogate": Surrogate(SingleTaskGP),
+ "surrogate": Surrogate(
+ surrogate_spec=SurrogateSpec(
+ botorch_model_class=SingleTaskGP
+ )
+ ),
# TODO: tests should better reflect defaults and not
# re-implement this logic.
"botorch_acqf_class": qNoisyExpectedImprovement,
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 7d90a339df1..007375df242 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -99,9 +99,9 @@
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
-from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
+from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.sebo import SEBOAcquisition
-from ax.models.torch.botorch_modular.surrogate import Surrogate
+from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.winsorization_config import WinsorizationConfig
from ax.runners.synthetic import SyntheticRunner
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb
index da21088c63c..dc8f72147ec 100644
--- a/tutorials/modular_botax.ipynb
+++ b/tutorials/modular_botax.ipynb
@@ -24,19 +24,20 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730904956474,
- "executionStopTime": 1730904963470,
- "id": "about-preview",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a"
},
- "originalKey": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7",
+ "id": "about-preview",
+ "originalKey": "f4e8ae18-2aa3-4943-a15a-29851889445c",
"outputsInitialized": true,
- "requestMsgId": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7",
- "serverExecutionDuration": 4351.2808320811
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916291451,
+ "executionStopTime": 1730916298337,
+ "serverExecutionDuration": 4531.2523420434,
+ "collapsed": false,
+ "requestMsgId": "f4e8ae18-2aa3-4943-a15a-29851889445c",
+ "customOutput": null
},
"source": [
"from typing import Any, Dict, Optional, Tuple, Type\n",
@@ -48,7 +49,8 @@
"\n",
"# Ax wrappers for BoTorch components\n",
"from ax.models.torch.botorch_modular.model import BoTorchModel\n",
- "from ax.models.torch.botorch_modular.surrogate import Surrogate\n",
+ "from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec\n",
+ "from ax.models.torch.botorch_modular.utils import ModelConfig\n",
"\n",
"# Experiment examination utilities\n",
"from ax.service.utils.report_utils import exp_to_df\n",
@@ -76,14 +78,14 @@
"output_type": "stream",
"name": "stderr",
"text": [
- "I1106 065557.352 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n"
+ "I1106 100452.333 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
- "I1106 065557.353 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n"
+ "I1106 100452.334 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n"
]
}
]
@@ -91,15 +93,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "northern-affairs",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a"
},
+ "id": "northern-affairs",
"originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"# Setup and Usage of BoTorch Models in Ax\n",
@@ -125,15 +127,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "pending-support",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a"
},
+ "id": "pending-support",
"originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 1. Quick-start example\n",
@@ -144,19 +146,20 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730904958719,
- "executionStopTime": 1730904963489,
- "id": "parental-sending",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8"
},
- "originalKey": "146a9a1b-52e6-4d76-9fc5-79025b392673",
+ "id": "parental-sending",
+ "originalKey": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09",
"outputsInitialized": true,
- "requestMsgId": "146a9a1b-52e6-4d76-9fc5-79025b392673",
- "serverExecutionDuration": 42.191333021037
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916294801,
+ "executionStopTime": 1730916298389,
+ "serverExecutionDuration": 22.605526028201,
+ "collapsed": false,
+ "requestMsgId": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09",
+ "customOutput": null
},
"source": [
"experiment = get_branin_experiment(with_trial=True)\n",
@@ -168,7 +171,7 @@
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:56:01] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
+ "[INFO 11-06 10:04:56] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
]
}
]
@@ -176,19 +179,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730904959343,
- "executionStopTime": 1730904964892,
- "id": "rough-somerset",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd"
},
- "originalKey": "aa532754-01ad-4441-84c1-2ac7f54ecf1e",
+ "id": "rough-somerset",
+ "originalKey": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92",
"outputsInitialized": true,
- "requestMsgId": "aa532754-01ad-4441-84c1-2ac7f54ecf1e",
- "serverExecutionDuration": 870.78339292202
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916295849,
+ "executionStopTime": 1730916299900,
+ "serverExecutionDuration": 852.73489891551,
+ "collapsed": false,
+ "requestMsgId": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92"
},
"source": [
"# `Models` automatically selects a model + model bridge combination.\n",
@@ -196,7 +199,9 @@
"model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n",
" experiment=experiment,\n",
" data=data,\n",
- " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n",
+ " surrogate_spec=SurrogateSpec(\n",
+ " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n",
+ " ), # Optional, will use default if unspecified\n",
" botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n",
")"
],
@@ -206,7 +211,7 @@
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:56:01] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
+ "[INFO 11-06 10:04:57] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
]
}
]
@@ -214,15 +219,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "hairy-wiring",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8"
},
+ "id": "hairy-wiring",
"originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)."
@@ -231,19 +236,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730904961333,
- "executionStopTime": 1730904964907,
- "id": "consecutive-summary",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7"
},
- "originalKey": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696",
+ "id": "consecutive-summary",
+ "originalKey": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f",
"outputsInitialized": true,
- "requestMsgId": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696",
- "serverExecutionDuration": 284.31268292479
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916299852,
+ "executionStopTime": 1730916300305,
+ "serverExecutionDuration": 233.20194100961,
+ "collapsed": false,
+ "requestMsgId": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f"
},
"source": [
"generator_run = model_bridge_with_GPEI.gen(n=1)\n",
@@ -254,7 +259,7 @@
{
"output_type": "execute_result",
"data": {
- "text/plain": "Arm(parameters={'x1': -5.0, 'x2': 0.0})"
+ "text/plain": "Arm(parameters={'x1': 10.0, 'x2': 15.0})"
},
"metadata": {},
"execution_count": 4
@@ -264,15 +269,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "diverse-richards",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626"
},
+ "id": "diverse-richards",
"originalKey": "804bac30-db07-4444-98a2-7a5f05007495",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"-----\n",
@@ -285,34 +290,34 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "grand-committee",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9"
},
+ "id": "grand-committee",
"originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 2. BoTorchModel = Surrogate + Acquisition\n",
"\n",
- "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class."
+ "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class."
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "thousand-blanket",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d"
},
+ "id": "thousand-blanket",
"originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"### 2A. Example that uses defaults and requires no options\n",
@@ -323,23 +328,21 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905022012,
- "executionStopTime": 1730905022269,
- "id": "changing-xerox",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c"
},
- "originalKey": "509e30d7-dc32-4190-836f-f221cacbff31",
+ "id": "changing-xerox",
+ "originalKey": "fa86552a-0b80-4040-a0c4-61a0de37bdc1",
"outputsInitialized": true,
- "requestMsgId": "509e30d7-dc32-4190-836f-f221cacbff31",
- "serverExecutionDuration": 1.972567057237
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916302730,
+ "executionStopTime": 1730916304031,
+ "serverExecutionDuration": 1.7747740494087,
+ "collapsed": false,
+ "requestMsgId": "fa86552a-0b80-4040-a0c4-61a0de37bdc1"
},
"source": [
- "from ax.models.torch.botorch_modular.utils import ModelConfig\n",
- "\n",
"# The surrogate is not specified, so it will be auto-selected\n",
"# during `model.fit`.\n",
"GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)\n",
@@ -347,27 +350,29 @@
"# The acquisition class is not specified, so it will be\n",
"# auto-selected during `model.gen` or `model.evaluate_acquisition`\n",
"GPEI_model = BoTorchModel(\n",
- " surrogate=Surrogate(model_configs=[ModelConfig(SingleTaskGP)])\n",
+ " surrogate_spec=SurrogateSpec(\n",
+ " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n",
+ " )\n",
")\n",
"\n",
"# Both the surrogate and acquisition class will be auto-selected.\n",
"GPEI_model = BoTorchModel()"
],
- "execution_count": 7,
+ "execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
- "id": "lovely-mechanics",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38"
},
+ "id": "lovely-mechanics",
"originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"### 2B. Example with all the options\n",
@@ -377,24 +382,24 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905023369,
- "executionStopTime": 1730905023622,
- "id": "twenty-greek",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a"
},
- "originalKey": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03",
+ "id": "twenty-greek",
+ "originalKey": "8d824e37-b087-4bab-9b16-4354e9509df7",
"outputsInitialized": true,
- "requestMsgId": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03",
- "serverExecutionDuration": 1.9794970285147
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916305930,
+ "executionStopTime": 1730916306168,
+ "serverExecutionDuration": 2.6916969800368,
+ "collapsed": false,
+ "requestMsgId": "8d824e37-b087-4bab-9b16-4354e9509df7"
},
"source": [
"model = BoTorchModel(\n",
" # Optional `Surrogate` specification to use instead of default\n",
- " surrogate=Surrogate(\n",
+ " surrogate_spec=SurrogateSpec(\n",
" model_configs=[\n",
" ModelConfig(\n",
" # BoTorch `Model` type\n",
@@ -421,21 +426,21 @@
" warm_start_refit=True,\n",
")"
],
- "execution_count": 8,
+ "execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
- "id": "fourth-material",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e"
},
+ "id": "fourth-material",
"originalKey": "7140bb19-09b4-4abe-951d-53902ae07833",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 2C. `Surrogate` and `Acquisition` Q&A\n",
@@ -450,15 +455,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "violent-course",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b"
},
+ "id": "violent-course",
"originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?"
@@ -467,18 +472,18 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "unlike-football",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298",
"showInput": false
},
+ "id": "unlike-football",
"originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"### 3a. Making a `Surrogate` from BoTorch `Model`:\n",
@@ -486,29 +491,30 @@
"\n",
"If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n",
"1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n",
- "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`."
+ "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via the `model_options` argument to `ModelConfig` in `SurrogateSpec`."
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905050779,
- "executionStopTime": 1730905051022,
- "id": "dynamic-university",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53"
},
- "originalKey": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7",
+ "id": "dynamic-university",
+ "originalKey": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34",
"outputsInitialized": true,
- "requestMsgId": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7",
- "serverExecutionDuration": 2.5736939860508
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916308518,
+ "executionStopTime": 1730916308769,
+ "serverExecutionDuration": 2.4644429795444,
+ "collapsed": false,
+ "requestMsgId": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34"
},
"source": [
+ "from botorch.models.model import Model\n",
"from botorch.utils.datasets import SupervisedDataset\n",
"\n",
"\n",
@@ -530,31 +536,31 @@
" }\n",
"\n",
"\n",
- "surrogate = Surrogate(\n",
+ "surrogate_spec = SurrogateSpec(\n",
" model_configs=[\n",
" ModelConfig(\n",
" botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n",
" # Optional dict of additional keyword arguments to `MyModelClass`\n",
" model_options={},\n",
" )\n",
- " ],\n",
+ " ]\n",
")"
],
- "execution_count": 9,
+ "execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
- "id": "otherwise-context",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c"
},
+ "id": "otherwise-context",
"originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial."
@@ -563,15 +569,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "northern-invite",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9"
},
+ "id": "northern-invite",
"originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax"
@@ -580,47 +586,49 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "surrounded-denial",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f",
"showInput": false
},
+ "id": "surrounded-denial",
"originalKey": "d4861847-b757-4fcd-9f35-ba258080812c",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Steps to set up any `AcquisitionFunction` in Ax are:\n",
- "1. Define an input constructor function. The purpose of this method is to produce arguments to an acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n",
+ "1. Define an input constructor function. The purpose of this method is to produce arguments to a acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n",
" 1. Note that the new input constructor needs to be decorated with `@acqf_input_constructor(AcquisitionFunctionClass)` to register it.\n",
- "2. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n",
- "3. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`."
+ "3. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n",
+ "4. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`."
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905067704,
- "executionStopTime": 1730905067939,
- "id": "interested-search",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"code_folding": [],
"hidden_ranges": [],
"originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a"
},
- "originalKey": "075243f0-c2d6-46fd-9a18-5b77af258abf",
+ "id": "interested-search",
+ "originalKey": "f188f40b-64ba-4b0c-b216-f3dea8c7465e",
"outputsInitialized": true,
- "requestMsgId": "075243f0-c2d6-46fd-9a18-5b77af258abf",
- "serverExecutionDuration": 5.0081580411643
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916310518,
+ "executionStopTime": 1730916310772,
+ "serverExecutionDuration": 4.9752569757402,
+ "collapsed": false,
+ "requestMsgId": "f188f40b-64ba-4b0c-b216-f3dea8c7465e",
+ "customOutput": null
},
"source": [
+ "from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse\n",
"from botorch.acquisition.acquisition import AcquisitionFunction\n",
"from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n",
"from botorch.utils.datasets import SupervisedDataset\n",
@@ -642,6 +650,7 @@
" pass\n",
"\n",
"\n",
+ "\n",
"# 2-3. Specifying `botorch_acqf_class` and `acquisition_options`\n",
"BoTorchModel(\n",
" botorch_acqf_class=MyAcquisitionFunctionClass,\n",
@@ -655,7 +664,7 @@
" },\n",
")"
],
- "execution_count": 10,
+ "execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
@@ -663,22 +672,22 @@
"text/plain": "BoTorchModel"
},
"metadata": {},
- "execution_count": 10
+ "execution_count": 8
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "metallic-imaging",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0"
},
+ "id": "metallic-imaging",
"originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example."
@@ -687,15 +696,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "descending-australian",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808"
},
+ "id": "descending-australian",
"originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 4. Using `Models.BOTORCH_MODULAR` \n",
@@ -708,19 +717,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905070804,
- "executionStopTime": 1730905071313,
- "id": "attached-border",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb"
},
- "originalKey": "980d8513-8607-4099-8cfe-7c8d7bf5afe9",
+ "id": "attached-border",
+ "originalKey": "052cf2e4-8de0-4ec3-a3f9-478194b10928",
"outputsInitialized": true,
- "requestMsgId": "980d8513-8607-4099-8cfe-7c8d7bf5afe9",
- "serverExecutionDuration": 262.54303695168
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916311983,
+ "executionStopTime": 1730916312395,
+ "serverExecutionDuration": 202.78578903526,
+ "collapsed": false,
+ "requestMsgId": "052cf2e4-8de0-4ec3-a3f9-478194b10928"
},
"source": [
"model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n",
@@ -729,13 +738,13 @@
")\n",
"model_bridge_with_GPEI.gen(1)"
],
- "execution_count": 11,
+ "execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:57:51] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
+ "[INFO 11-06 10:05:12] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
]
},
{
@@ -744,31 +753,31 @@
"text/plain": "GeneratorRun(1 arms, total weight 1.0)"
},
"metadata": {},
- "execution_count": 11
+ "execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905071744,
- "executionStopTime": 1730905071981,
- "id": "powerful-gamma",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "89930a31-e058-434b-b587-181931e247b6"
},
- "originalKey": "6ec047f0-a75e-4733-b4d7-20045627f0b2",
+ "id": "powerful-gamma",
+ "originalKey": "b7f924fe-f3d9-4211-b402-421f4c90afe5",
"outputsInitialized": true,
- "requestMsgId": "6ec047f0-a75e-4733-b4d7-20045627f0b2",
- "serverExecutionDuration": 2.8772989753634
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916312432,
+ "executionStopTime": 1730916312657,
+ "serverExecutionDuration": 3.1334219966084,
+ "collapsed": false,
+ "requestMsgId": "b7f924fe-f3d9-4211-b402-421f4c90afe5"
},
"source": [
"model_bridge_with_GPEI.model.botorch_acqf_class"
],
- "execution_count": 12,
+ "execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
@@ -776,31 +785,31 @@
"text/plain": "botorch.acquisition.logei.qLogNoisyExpectedImprovement"
},
"metadata": {},
- "execution_count": 12
+ "execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905076319,
- "executionStopTime": 1730905076580,
- "id": "improved-replication",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01"
},
- "originalKey": "68caa729-792d-4692-bc4d-4c0b8d03e022",
+ "id": "improved-replication",
+ "originalKey": "942f1817-8d40-48f8-8725-90c25a079e4c",
"outputsInitialized": true,
- "requestMsgId": "68caa729-792d-4692-bc4d-4c0b8d03e022",
- "serverExecutionDuration": 3.2039729412645
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916312847,
+ "executionStopTime": 1730916313093,
+ "serverExecutionDuration": 3.410067060031,
+ "collapsed": false,
+ "requestMsgId": "942f1817-8d40-48f8-8725-90c25a079e4c"
},
"source": [
- "type(model_bridge_with_GPEI.model.surrogate.model)"
+ "model_bridge_with_GPEI.model.surrogate.model.__class__"
],
- "execution_count": 13,
+ "execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
@@ -808,22 +817,22 @@
"text/plain": "botorch.models.gp_regression.SingleTaskGP"
},
"metadata": {},
- "execution_count": 13
+ "execution_count": 11
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "connected-sheet",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262"
},
+ "id": "connected-sheet",
"originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:"
@@ -832,19 +841,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905078786,
- "executionStopTime": 1730905079551,
- "id": "documentary-jurisdiction",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5"
},
- "originalKey": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be",
+ "id": "documentary-jurisdiction",
+ "originalKey": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b",
"outputsInitialized": true,
- "requestMsgId": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be",
- "serverExecutionDuration": 512.32317101676
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916314009,
+ "executionStopTime": 1730916314736,
+ "serverExecutionDuration": 518.53136904538,
+ "collapsed": false,
+ "requestMsgId": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b"
},
"source": [
"model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n",
@@ -855,27 +864,27 @@
")\n",
"model_bridge_with_EHVI.gen(1)"
],
- "execution_count": 14,
+ "execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:57:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
+ "[INFO 11-06 10:05:14] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n"
+ "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n"
+ "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n"
]
},
{
@@ -884,31 +893,31 @@
"text/plain": "GeneratorRun(1 arms, total weight 1.0)"
},
"metadata": {},
- "execution_count": 14
+ "execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905079324,
- "executionStopTime": 1730905079651,
- "id": "changed-maintenance",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566"
},
- "originalKey": "87247bf1-04f2-4a8a-92f6-18174d70cbb7",
+ "id": "changed-maintenance",
+ "originalKey": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f",
"outputsInitialized": true,
- "requestMsgId": "87247bf1-04f2-4a8a-92f6-18174d70cbb7",
- "serverExecutionDuration": 3.0141109600663
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916314586,
+ "executionStopTime": 1730916314842,
+ "serverExecutionDuration": 3.3097150735557,
+ "collapsed": false,
+ "requestMsgId": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f"
},
"source": [
"model_bridge_with_EHVI.model.botorch_acqf_class"
],
- "execution_count": 15,
+ "execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
@@ -916,31 +925,31 @@
"text/plain": "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement"
},
"metadata": {},
- "execution_count": 15
+ "execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905087115,
- "executionStopTime": 1730905087362,
- "id": "operating-shelf",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "16727a51-337d-4715-bf51-9cb6637a950f"
},
- "originalKey": "22fefbbf-68a9-4d5d-ade9-9df425995c3b",
+ "id": "operating-shelf",
+ "originalKey": "1e980e3c-09f6-44c1-a79f-f59867de0c3e",
"outputsInitialized": true,
- "requestMsgId": "22fefbbf-68a9-4d5d-ade9-9df425995c3b",
- "serverExecutionDuration": 3.2659249845892
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916315097,
+ "executionStopTime": 1730916315308,
+ "serverExecutionDuration": 3.4662369871512,
+ "collapsed": false,
+ "requestMsgId": "1e980e3c-09f6-44c1-a79f-f59867de0c3e"
},
"source": [
- "type(model_bridge_with_EHVI.model.surrogate.model)"
+ "model_bridge_with_EHVI.model.surrogate.model.__class__"
],
- "execution_count": 16,
+ "execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
@@ -948,22 +957,22 @@
"text/plain": "botorch.models.gp_regression.SingleTaskGP"
},
"metadata": {},
- "execution_count": 16
+ "execution_count": 14
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "fatal-butterfly",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae"
},
+ "id": "fatal-butterfly",
"originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. "
@@ -972,15 +981,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "hearing-interface",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b"
},
+ "id": "hearing-interface",
"originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 5. Utilizing `BoTorchModel` in generation strategies\n",
@@ -993,19 +1002,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905104120,
- "executionStopTime": 1730905104377,
- "id": "received-registration",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350"
},
- "originalKey": "d0303c24-98bb-4c89-87cb-fa32ff498bd4",
+ "id": "received-registration",
+ "originalKey": "4ee172c8-0648-418b-9968-647e8e916507",
"outputsInitialized": true,
- "requestMsgId": "d0303c24-98bb-4c89-87cb-fa32ff498bd4",
- "serverExecutionDuration": 2.4519630242139
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916316730,
+ "executionStopTime": 1730916316968,
+ "serverExecutionDuration": 2.2927720565349,
+ "collapsed": false,
+ "requestMsgId": "4ee172c8-0648-418b-9968-647e8e916507"
},
"source": [
"from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n",
@@ -1028,7 +1037,7 @@
" # No limit on how many generator runs will be produced\n",
" num_trials=-1,\n",
" model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n",
- " \"surrogate\": Surrogate(\n",
+ " \"surrogate_spec\": SurrogateSpec(\n",
" model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n",
" ),\n",
" \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n",
@@ -1037,21 +1046,21 @@
" ]\n",
")"
],
- "execution_count": 17,
+ "execution_count": 15,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
- "id": "logical-windsor",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "212c4543-220e-4605-8f72-5f86cf52f722"
},
+ "id": "logical-windsor",
"originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:"
@@ -1060,19 +1069,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905105295,
- "executionStopTime": 1730905105570,
- "id": "viral-cheese",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6"
},
- "originalKey": "e8aa4013-eb6f-4f50-a5f0-f963369495ed",
+ "id": "viral-cheese",
+ "originalKey": "1b7d0cfc-f7cf-477d-b109-d34db9604938",
"outputsInitialized": true,
- "requestMsgId": "e8aa4013-eb6f-4f50-a5f0-f963369495ed",
- "serverExecutionDuration": 4.1721769375727
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916317751,
+ "executionStopTime": 1730916318153,
+ "serverExecutionDuration": 3.9581339806318,
+ "collapsed": false,
+ "requestMsgId": "1b7d0cfc-f7cf-477d-b109-d34db9604938"
},
"source": [
"experiment = get_branin_experiment(minimize=True)\n",
@@ -1080,13 +1089,13 @@
"assert len(experiment.trials) == 0\n",
"experiment.search_space"
],
- "execution_count": 18,
+ "execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:58:25] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
+ "[INFO 11-06 10:05:18] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n"
]
},
{
@@ -1095,22 +1104,22 @@
"text/plain": "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])"
},
"metadata": {},
- "execution_count": 18
+ "execution_count": 16
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "incident-newspaper",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e"
},
+ "id": "incident-newspaper",
"originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 5a. Specifying `pending_observations`\n",
@@ -1122,19 +1131,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905107391,
- "executionStopTime": 1730905109351,
- "id": "casual-spread",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e"
},
- "originalKey": "36ca75ec-b37c-498c-a487-5652cd3dc34b",
+ "id": "casual-spread",
+ "originalKey": "fe7437c5-8834-46cc-94b2-91782d91ee96",
"outputsInitialized": true,
- "requestMsgId": "36ca75ec-b37c-498c-a487-5652cd3dc34b",
- "serverExecutionDuration": 1696.8911510194
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916318830,
+ "executionStopTime": 1730916321328,
+ "serverExecutionDuration": 2274.8276960338,
+ "collapsed": false,
+ "requestMsgId": "fe7437c5-8834-46cc-94b2-91782d91ee96"
},
"source": [
"for _ in range(10):\n",
@@ -1155,7 +1164,7 @@
"\n",
" print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")"
],
- "execution_count": 19,
+ "execution_count": 17,
"outputs": [
{
"output_type": "stream",
@@ -1175,7 +1184,14 @@
"output_type": "stream",
"name": "stdout",
"text": [
- "Completed trial #6, suggested by BoTorch.\nCompleted trial #7, suggested by BoTorch.\n"
+ "Completed trial #6, suggested by BoTorch.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Completed trial #7, suggested by BoTorch.\n"
]
},
{
@@ -1197,15 +1213,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "circular-vermont",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c"
},
+ "id": "circular-vermont",
"originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:"
@@ -1214,37 +1230,37 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905108832,
- "executionStopTime": 1730905109399,
- "id": "significant-particular",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1"
},
- "originalKey": "68d567be-27d6-4244-b22f-d6c53ed2d303",
+ "id": "significant-particular",
+ "originalKey": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac",
"outputsInitialized": true,
- "requestMsgId": "68d567be-27d6-4244-b22f-d6c53ed2d303",
- "serverExecutionDuration": 32.219067099504
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916319576,
+ "executionStopTime": 1730916321368,
+ "serverExecutionDuration": 35.789265064523,
+ "collapsed": false,
+ "requestMsgId": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac"
},
"source": [
"exp_to_df(experiment)"
],
- "execution_count": 20,
+ "execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[WARNING 11-06 06:58:29] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n"
+ "[WARNING 11-06 10:05:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n"
]
},
{
"output_type": "execute_result",
"data": {
- "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 79.581199 1.380743 12.280850\n1 1 1_0 COMPLETED ... 17.366840 6.989676 1.438049\n2 2 2_0 COMPLETED ... 61.299075 6.097525 7.568626\n3 3 3_0 COMPLETED ... 71.268812 -3.293570 4.231312\n4 4 4_0 COMPLETED ... 3.831238 -2.268755 10.230113\n5 5 5_0 COMPLETED ... 4.246417 -3.354258 10.886093\n6 6 6_0 COMPLETED ... 6.712767 9.467421 0.000000\n7 7 7_0 COMPLETED ... 17.508300 -5.000000 15.000000\n8 8 8_0 COMPLETED ... 2.507635 10.000000 2.251628\n9 9 9_0 COMPLETED ... 1.318731 8.983844 2.177537\n\n[10 rows x 7 columns]",
- "text/html": "
\n\n
\n \n \n | \n trial_index | \n arm_name | \n trial_status | \n generation_method | \n branin | \n x1 | \n x2 | \n
\n \n \n \n 0 | \n 0 | \n 0_0 | \n COMPLETED | \n Sobol | \n 79.581199 | \n 1.380743 | \n 12.280850 | \n
\n \n 1 | \n 1 | \n 1_0 | \n COMPLETED | \n Sobol | \n 17.366840 | \n 6.989676 | \n 1.438049 | \n
\n \n 2 | \n 2 | \n 2_0 | \n COMPLETED | \n Sobol | \n 61.299075 | \n 6.097525 | \n 7.568626 | \n
\n \n 3 | \n 3 | \n 3_0 | \n COMPLETED | \n Sobol | \n 71.268812 | \n -3.293570 | \n 4.231312 | \n
\n \n 4 | \n 4 | \n 4_0 | \n COMPLETED | \n Sobol | \n 3.831238 | \n -2.268755 | \n 10.230113 | \n
\n \n 5 | \n 5 | \n 5_0 | \n COMPLETED | \n BoTorch | \n 4.246417 | \n -3.354258 | \n 10.886093 | \n
\n \n 6 | \n 6 | \n 6_0 | \n COMPLETED | \n BoTorch | \n 6.712767 | \n 9.467421 | \n 0.000000 | \n
\n \n 7 | \n 7 | \n 7_0 | \n COMPLETED | \n BoTorch | \n 17.508300 | \n -5.000000 | \n 15.000000 | \n
\n \n 8 | \n 8 | \n 8_0 | \n COMPLETED | \n BoTorch | \n 2.507635 | \n 10.000000 | \n 2.251628 | \n
\n \n 9 | \n 9 | \n 9_0 | \n COMPLETED | \n BoTorch | \n 1.318731 | \n 8.983844 | \n 2.177537 | \n
\n \n
\n
",
+ "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 26.922506 -2.244023 5.435609\n1 1 1_0 COMPLETED ... 74.072517 3.535081 10.528676\n2 2 2_0 COMPLETED ... 5.610080 8.741262 3.706691\n3 3 3_0 COMPLETED ... 56.657623 -0.069164 12.199905\n4 4 4_0 COMPLETED ... 27.932704 0.862014 1.306074\n5 5 5_0 COMPLETED ... 5.423062 10.000000 4.868411\n6 6 6_0 COMPLETED ... 9.250452 10.000000 0.299753\n7 7 7_0 COMPLETED ... 308.129096 -5.000000 0.000000\n8 8 8_0 COMPLETED ... 17.607633 0.778687 5.717932\n9 9 9_0 COMPLETED ... 132.986209 1.451895 15.000000\n\n[10 rows x 7 columns]",
+ "text/html": "\n\n
\n \n \n | \n trial_index | \n arm_name | \n trial_status | \n generation_method | \n branin | \n x1 | \n x2 | \n
\n \n \n \n 0 | \n 0 | \n 0_0 | \n COMPLETED | \n Sobol | \n 26.922506 | \n -2.244023 | \n 5.435609 | \n
\n \n 1 | \n 1 | \n 1_0 | \n COMPLETED | \n Sobol | \n 74.072517 | \n 3.535081 | \n 10.528676 | \n
\n \n 2 | \n 2 | \n 2_0 | \n COMPLETED | \n Sobol | \n 5.610080 | \n 8.741262 | \n 3.706691 | \n
\n \n 3 | \n 3 | \n 3_0 | \n COMPLETED | \n Sobol | \n 56.657623 | \n -0.069164 | \n 12.199905 | \n
\n \n 4 | \n 4 | \n 4_0 | \n COMPLETED | \n Sobol | \n 27.932704 | \n 0.862014 | \n 1.306074 | \n
\n \n 5 | \n 5 | \n 5_0 | \n COMPLETED | \n BoTorch | \n 5.423062 | \n 10.000000 | \n 4.868411 | \n
\n \n 6 | \n 6 | \n 6_0 | \n COMPLETED | \n BoTorch | \n 9.250452 | \n 10.000000 | \n 0.299753 | \n
\n \n 7 | \n 7 | \n 7_0 | \n COMPLETED | \n BoTorch | \n 308.129096 | \n -5.000000 | \n 0.000000 | \n
\n \n 8 | \n 8 | \n 8_0 | \n COMPLETED | \n BoTorch | \n 17.607633 | \n 0.778687 | \n 5.717932 | \n
\n \n 9 | \n 9 | \n 9_0 | \n COMPLETED | \n BoTorch | \n 132.986209 | \n 1.451895 | \n 15.000000 | \n
\n \n
\n
",
"application/vnd.dataresource+json": {
"schema": {
"fields": [
@@ -1293,9 +1309,9 @@
"arm_name": "0_0",
"trial_status": "COMPLETED",
"generation_method": "Sobol",
- "branin": 79.5811993025,
- "x1": 1.3807432353,
- "x2": 12.2808498144
+ "branin": 26.9225058393,
+ "x1": -2.2440226376,
+ "x2": 5.4356087744
},
{
"index": 1,
@@ -1303,9 +1319,9 @@
"arm_name": "1_0",
"trial_status": "COMPLETED",
"generation_method": "Sobol",
- "branin": 17.3668397479,
- "x1": 6.9896756578,
- "x2": 1.4380489429
+ "branin": 74.0725171307,
+ "x1": 3.535081069,
+ "x2": 10.5286756391
},
{
"index": 2,
@@ -1313,9 +1329,9 @@
"arm_name": "2_0",
"trial_status": "COMPLETED",
"generation_method": "Sobol",
- "branin": 61.2990749448,
- "x1": 6.097525456,
- "x2": 7.5686264969
+ "branin": 5.6100798162,
+ "x1": 8.7412616471,
+ "x2": 3.7066908041
},
{
"index": 3,
@@ -1323,9 +1339,9 @@
"arm_name": "3_0",
"trial_status": "COMPLETED",
"generation_method": "Sobol",
- "branin": 71.268812081,
- "x1": -3.2935703546,
- "x2": 4.231311623
+ "branin": 56.6576230229,
+ "x1": -0.0691637676,
+ "x2": 12.1999046439
},
{
"index": 4,
@@ -1333,9 +1349,9 @@
"arm_name": "4_0",
"trial_status": "COMPLETED",
"generation_method": "Sobol",
- "branin": 3.8312383283,
- "x1": -2.2687551333,
- "x2": 10.2301133936
+ "branin": 27.9327040954,
+ "x1": 0.8620139305,
+ "x2": 1.3060741313
},
{
"index": 5,
@@ -1343,9 +1359,9 @@
"arm_name": "5_0",
"trial_status": "COMPLETED",
"generation_method": "BoTorch",
- "branin": 4.2464169491,
- "x1": -3.3542581623,
- "x2": 10.8860926765
+ "branin": 5.4230616409,
+ "x1": 10,
+ "x2": 4.8684112356
},
{
"index": 6,
@@ -1353,9 +1369,9 @@
"arm_name": "6_0",
"trial_status": "COMPLETED",
"generation_method": "BoTorch",
- "branin": 6.7127667696,
- "x1": 9.4674207228,
- "x2": 0
+ "branin": 9.2504522786,
+ "x1": 10,
+ "x2": 0.2997526514
},
{
"index": 7,
@@ -1363,9 +1379,9 @@
"arm_name": "7_0",
"trial_status": "COMPLETED",
"generation_method": "BoTorch",
- "branin": 17.5082995158,
+ "branin": 308.1290960116,
"x1": -5,
- "x2": 15
+ "x2": 0
},
{
"index": 8,
@@ -1373,9 +1389,9 @@
"arm_name": "8_0",
"trial_status": "COMPLETED",
"generation_method": "BoTorch",
- "branin": 2.5076350936,
- "x1": 10,
- "x2": 2.2516281613
+ "branin": 17.6076329851,
+ "x1": 0.7786866384,
+ "x2": 5.7179317285
},
{
"index": 9,
@@ -1383,30 +1399,30 @@
"arm_name": "9_0",
"trial_status": "COMPLETED",
"generation_method": "BoTorch",
- "branin": 1.3187305399,
- "x1": 8.983844442,
- "x2": 2.1775366828
+ "branin": 132.9862090134,
+ "x1": 1.451894724,
+ "x2": 15
}
]
}
},
"metadata": {},
- "execution_count": 20
+ "execution_count": 18
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "obvious-transparency",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783"
},
+ "id": "obvious-transparency",
"originalKey": "633c66af-a89f-4f03-a88b-866767d0a52f",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## 6. Customizing a `Surrogate` or `Acquisition`\n",
@@ -1419,19 +1435,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905111356,
- "executionStopTime": 1730905111601,
- "id": "organizational-balance",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af"
},
- "originalKey": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693",
+ "id": "organizational-balance",
+ "originalKey": "2949718a-8a4e-41e5-91ac-5b020eface47",
"outputsInitialized": true,
- "requestMsgId": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693",
- "serverExecutionDuration": 2.4916339898482
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916320585,
+ "executionStopTime": 1730916321384,
+ "serverExecutionDuration": 2.2059100447223,
+ "collapsed": false,
+ "requestMsgId": "2949718a-8a4e-41e5-91ac-5b020eface47"
},
"source": [
"from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n",
@@ -1451,21 +1467,21 @@
" ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n",
" ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default"
],
- "execution_count": 21,
+ "execution_count": 19,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
- "id": "theoretical-horizon",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9"
},
+ "id": "theoretical-horizon",
"originalKey": "0ec8606d-9d5b-4bcb-ad7e-f54839ad6f9b",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):"
@@ -1474,19 +1490,19 @@
{
"cell_type": "code",
"metadata": {
- "collapsed": false,
- "executionStartTime": 1730905113131,
- "executionStopTime": 1730905113387,
- "id": "approximate-rolling",
- "isAgentGenerated": false,
- "language": "python",
"metadata": {
"originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc"
},
- "originalKey": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab",
+ "id": "approximate-rolling",
+ "originalKey": "e231ea1e-c70d-48dc-b6c6-1611c5ea1b26",
"outputsInitialized": true,
- "requestMsgId": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab",
- "serverExecutionDuration": 12.22031598445
+ "isAgentGenerated": false,
+ "language": "python",
+ "executionStartTime": 1730916321675,
+ "executionStopTime": 1730916321901,
+ "serverExecutionDuration": 12.351316981949,
+ "collapsed": false,
+ "requestMsgId": "e231ea1e-c70d-48dc-b6c6-1611c5ea1b26"
},
"source": [
"Models.BOTORCH_MODULAR(\n",
@@ -1496,13 +1512,13 @@
" botorch_acqf_class=MyAcquisitionFunctionClass,\n",
")"
],
- "execution_count": 22,
+ "execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
- "[INFO 11-06 06:58:33] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
+ "[INFO 11-06 10:05:21] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n"
]
},
{
@@ -1511,22 +1527,22 @@
"text/plain": "TorchModelBridge(model=BoTorchModel)"
},
"metadata": {},
- "execution_count": 22
+ "execution_count": 20
}
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "representative-implement",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256"
},
+ "id": "representative-implement",
"originalKey": "cdcfb2bc-3016-4681-9fff-407f28321c3f",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n",
@@ -1542,15 +1558,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "framed-intermediate",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361"
},
+ "id": "framed-intermediate",
"originalKey": "ff03d674-f584-403f-ba65-f1bab921845b",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"------"
@@ -1559,15 +1575,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "metropolitan-feedback",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0"
},
+ "id": "metropolitan-feedback",
"originalKey": "f71fcfa1-fc59-4bfb-84d6-b94ea5298bfa",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## Appendix 1: Methods available on `BoTorchModel`\n",
@@ -1590,15 +1606,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "possible-transsexual",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12"
},
+ "id": "possible-transsexual",
"originalKey": "91cedde4-8911-441f-af05-eb124581cbbc",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## Appendix 2: Default surrogate models and acquisition functions\n",
@@ -1617,15 +1633,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "continuous-strain",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6"
},
+ "id": "continuous-strain",
"originalKey": "c8b0f933-8df6-479b-aa61-db75ca877624",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n",
@@ -1636,15 +1652,15 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "broadband-voice",
- "isAgentGenerated": false,
- "language": "markdown",
"metadata": {
"originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc"
},
+ "id": "broadband-voice",
"originalKey": "4d82f49a-3a8b-42f0-a4f5-5c079b793344",
+ "showInput": false,
"outputsInitialized": false,
- "showInput": false
+ "isAgentGenerated": false,
+ "language": "markdown"
},
"source": [
"The two options for handling this error are:\n",