diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 74518e3ebe7..d4fbeda78a2 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -14,7 +14,7 @@ from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Models -from ax.models.torch.botorch_modular.model import SurrogateSpec +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.service.scheduler import SchedulerOptions from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement diff --git a/ax/benchmark/tests/methods/test_methods.py b/ax/benchmark/tests/methods/test_methods.py index 9c63e2705e0..0dc1aa60cee 100644 --- a/ax/benchmark/tests/methods/test_methods.py +++ b/ax/benchmark/tests/methods/test_methods.py @@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None: self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient) surrogate_spec = model_kwargs["surrogate_spec"] self.assertEqual( - surrogate_spec.botorch_model_class.__name__, + surrogate_spec.model_configs[0].botorch_model_class.__name__, "SingleTaskGP", ) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 9d903ad8c9e..27f0ea653a8 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -59,10 +59,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.random.uniform import UniformGenerator from ax.models.torch.botorch import BotorchModel -from ax.models.torch.botorch_modular.model import ( - BoTorchModel as ModularBoTorchModel, - SurrogateSpec, -) +from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.models.torch.cbo_sac import SACBO from ax.utils.common.kwargs import ( diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 418c4bcf889..20f17d7ce70 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -10,6 +10,7 @@ import torch from ax.core.observation import ObservationFeatures +from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.modelbridge.discrete import DiscreteModelBridge from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import ( @@ -27,8 +28,8 @@ from ax.models.random.sobol import SobolGenerator from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None: SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP + 
saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class, + SaasFullyBayesianSingleTaskGP, ) @mock_botorch_optimize @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: self.assertIsInstance(mtgp, TorchModelBridge) self.assertIsInstance(mtgp.model, BoTorchModel) self.assertEqual(mtgp.model.acquisition_class, Acquisition) - self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + is_moo = isinstance( + exp.optimization_config, MultiObjectiveOptimizationConfig + ) + if is_moo: + self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + models = mtgp.model.surrogate.model.models + else: + models = [mtgp.model.surrogate.model] - for model in mtgp.model.surrogate.model.models: + for model in models: self.assertIsInstance( model, SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index a5e8f84aedc..ebd04a2f2e3 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -10,7 +10,6 @@ import warnings from collections import OrderedDict from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field from typing import Any import numpy.typing as npt @@ -23,7 +22,7 @@ get_rounding_func, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( check_outcome_dataset_match, choose_botorch_acqf_class, @@ -37,51 +36,10 @@ from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.models.deterministic import FixedSingleSampleModel -from botorch.models.model import Model -from botorch.models.transforms.input import InputTransform -from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset -from gpytorch.kernels.kernel import Kernel -from gpytorch.likelihoods import Likelihood -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor -@dataclass(frozen=True) -class SurrogateSpec: - """ - Fields in the SurrogateSpec dataclass correspond to arguments in - ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which - outcomes the Surrogate is responsible for modeling. - When ``BotorchModel.fit`` is called, these fields will be used to construct the - requisite Surrogate objects. - If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. 
- """ - - botorch_model_class: type[Model] | None = None - botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) - - mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood - mll_kwargs: dict[str, Any] = field(default_factory=dict) - - covar_module_class: type[Kernel] | None = None - covar_module_kwargs: dict[str, Any] | None = None - - likelihood_class: type[Likelihood] | None = None - likelihood_kwargs: dict[str, Any] | None = None - - input_transform_classes: list[type[InputTransform]] | None = None - input_transform_options: dict[str, dict[str, Any]] | None = None - - outcome_transform_classes: list[type[OutcomeTransform]] | None = None - outcome_transform_options: dict[str, dict[str, Any]] | None = None - - allow_batched_models: bool = True - - outcomes: list[str] = field(default_factory=list) - - class BoTorchModel(TorchModel, Base): """**All classes in 'botorch_modular' directory are under construction, incomplete, and should be treated as alpha @@ -227,27 +185,19 @@ def fit( # If a surrogate has not been constructed, construct it. if self._surrogate is None: - if (spec := self.surrogate_spec) is not None: - self._surrogate = Surrogate( - botorch_model_class=spec.botorch_model_class, - model_options=spec.botorch_model_kwargs, - mll_class=spec.mll_class, - mll_options=spec.mll_kwargs, - covar_module_class=spec.covar_module_class, - covar_module_options=spec.covar_module_kwargs, - likelihood_class=spec.likelihood_class, - likelihood_options=spec.likelihood_kwargs, - input_transform_classes=spec.input_transform_classes, - input_transform_options=spec.input_transform_options, - outcome_transform_classes=spec.outcome_transform_classes, - outcome_transform_options=spec.outcome_transform_options, - allow_batched_models=spec.allow_batched_models, - ) + if self.surrogate_spec is not None: + self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec) else: self._surrogate = Surrogate() # Fit the surrogate. 
- self.surrogate.model_options.update(additional_model_inputs) + for config in self.surrogate.surrogate_spec.model_configs: + config.model_options.update(additional_model_inputs) + for ( + config_list + ) in self.surrogate.surrogate_spec.metric_to_model_configs.values(): + for config in config_list: + config.model_options.update(additional_model_inputs) self.surrogate.fit( datasets=datasets, search_space_digest=search_space_digest, diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 0fab257ffca..802ed179b70 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -9,9 +9,11 @@ from __future__ import annotations import inspect +import warnings from collections import OrderedDict from collections.abc import Sequence from copy import deepcopy +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -33,6 +35,7 @@ choose_model_class, convert_to_block_design, fit_botorch_model, + ModelConfig, subset_state_dict, use_model_list, ) @@ -50,11 +53,11 @@ _argparse_type_encoder, checked_cast, checked_cast_optional, + not_none, ) from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP -from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, @@ -68,6 +71,7 @@ from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from pyre_extensions import none_throws from torch import Tensor NOT_YET_FIT_MSG = ( @@ -268,6 +272,171 @@ def _set_formatted_inputs( formatted_model_inputs[input_name] = input_class(**input_options) +def _raise_deprecation_warning( + is_surrogate: bool = False, + **kwargs: Any, +) -> bool: + """Raise deprecation warnings for deprecated arguments. + + Args: + is_surrogate: A boolean indicating whether the warning is called from + Surrogate. + + Returns: + A boolean indicating whether any deprecation warnings were raised. + """ + msg = "{k} is deprecated and will be removed in a future version. " + if is_surrogate: + msg += "Please specify {k} via `surrogate_spec.model_configs`." + else: + msg += "Please specify {k} via `model_configs`." 
+ warnings_raised = False + default_is_dict = {"botorch_model_kwargs", "mll_kwargs"} + for k, v in kwargs.items(): + should_raise = False + if k in default_is_dict: + if v != {}: + should_raise = True + elif (v is not None and k != "mll_class") or ( + k == "mll_class" and v is not ExactMarginalLogLikelihood + ): + should_raise = True + if should_raise: + warnings.warn( + msg.format(k=k), + DeprecationWarning, + stacklevel=3, + ) + warnings_raised = True + return warnings_raised + + +def get_model_config_from_deprecated_args( + botorch_model_class: type[Model] | None, + model_options: dict[str, Any] | None, + mll_class: type[MarginalLogLikelihood] | None, + mll_options: dict[str, Any] | None, + outcome_transform_classes: list[type[OutcomeTransform]] | None, + outcome_transform_options: dict[str, dict[str, Any]] | None, + input_transform_classes: list[type[InputTransform]] | None, + input_transform_options: dict[str, dict[str, Any]] | None, + covar_module_class: type[Kernel] | None, + covar_module_options: dict[str, Any] | None, + likelihood_class: type[Likelihood] | None, + likelihood_options: dict[str, Any] | None, +) -> ModelConfig: + """Construct a ModelConfig from deprecated arguments.""" + model_config_kwargs = { + "botorch_model_class": botorch_model_class, + "model_options": (model_options or {}).copy(), + "mll_class": mll_class, + "mll_options": (mll_options or {}).copy(), + "outcome_transform_classes": outcome_transform_classes, + "outcome_transform_options": outcome_transform_options, + "input_transform_classes": input_transform_classes, + "input_transform_options": input_transform_options, + "covar_module_class": covar_module_class, + "covar_module_options": covar_module_options, + "likelihood_class": likelihood_class, + "likelihood_options": likelihood_options, + } + model_config_kwargs = { + k: v for k, v in model_config_kwargs.items() if v is not None + } + # pyre-fixme [6]: Incompatible parameter type [6]: In call + # `ModelConfig.__init__`, for 1st positional argument, expected + # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], + # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], + # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, + # Model]], Type[Likelihood], Type[Kernel]]`. + return ModelConfig(**model_config_kwargs) + + +@dataclass(frozen=True) +class SurrogateSpec: + """ + Fields in the SurrogateSpec dataclass correspond to arguments in + ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which + outcomes the Surrogate is responsible for modeling. + When ``BotorchModel.fit`` is called, these fields will be used to construct the + requisite Surrogate objects. + If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. 
+ """ + + botorch_model_class: type[Model] | None = None + botorch_model_kwargs: dict[str, Any] = field(default_factory=dict) + + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_kwargs: dict[str, Any] = field(default_factory=dict) + + covar_module_class: type[Kernel] | None = None + covar_module_kwargs: dict[str, Any] | None = None + + likelihood_class: type[Likelihood] | None = None + likelihood_kwargs: dict[str, Any] | None = None + + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = None + + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] | None = None + + allow_batched_models: bool = True + + model_configs: list[ModelConfig] = field(default_factory=list) + metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) + outcomes: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + warnings_raised = _raise_deprecation_warning( + is_surrogate=False, + botorch_model_class=self.botorch_model_class, + botorch_model_kwargs=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_kwargs=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + if len(self.model_configs) == 0: + model_config = get_model_config_from_deprecated_args( + botorch_model_class=self.botorch_model_class, + model_options=self.botorch_model_kwargs, + mll_class=self.mll_class, + mll_options=self.mll_kwargs, + outcome_transform_classes=self.outcome_transform_classes, + outcome_transform_options=self.outcome_transform_options, + input_transform_classes=self.input_transform_classes, + input_transform_options=self.input_transform_options, + covar_module_class=self.covar_module_class, + covar_module_options=self.covar_module_kwargs, + likelihood_class=self.likelihood_class, + likelihood_options=self.likelihood_kwargs, + ) + # re-initialize with the non-deprecated arguments + self.__init__( + model_configs=[model_config], + metric_to_model_configs=self.metric_to_model_configs, + allow_batched_models=self.allow_batched_models, + outcomes=self.outcomes, + ) + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) + if len(self.model_configs) > 1 or any( + len(model_config) > 1 + for model_config in self.metric_to_model_configs.values() + ): + raise NotImplementedError("Only one model config per metric is supported.") + + class Surrogate(Base): """ **All classes in 'botorch_modular' directory are under @@ -282,15 +451,20 @@ class Surrogate(Base): BoTorch model. If None is provided a model class will be selected (either one for all outcomes or a ModelList with separate models for each outcome) will be selected automatically based off the datasets at `construct` time. + This argument is deprecated in favor of model_configs. model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. 
Note that the corresponding attribute will later be updated to include any additional kwargs passed into ``BoTorchModel.fit``. + This argument is deprecated in favor of model_configs. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. - mll_options: Dictionary of options / kwargs for the MLL. + This argument is deprecated in favor of model_configs. + mll_options: Dictionary of options / kwargs for the MLL. This argument is + deprecated in favor of model_configs. outcome_transform_classes: List of BoTorch outcome transforms classes. Passed down to the BoTorch ``Model``. Multiple outcome transforms can be chained - together using ``ChainedOutcomeTransform``. + together using ``ChainedOutcomeTransform``. This argument is deprecated in + favor of model_configs. outcome_transform_options: Outcome transform classes kwargs. The keys are class string names and the values are dictionaries of outcome transform kwargs. For example, @@ -299,10 +473,12 @@ class string names and the values are dictionaries of outcome transform outcome_transform_options = { "Standardize": {"m": 1}, ` - For more options see `botorch/models/transforms/outcome.py`. + For more options see `botorch/models/transforms/outcome.py`. This argument + is deprecated in favor of model_configs. input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + This argument is deprecated in favor of model_configs. input_transform_options: Input transform classes kwargs. The keys are class string names and the values are dictionaries of input transform kwargs. For example, @@ -314,26 +490,36 @@ class string names and the values are dictionaries of input transform } ` For more input options see `botorch/models/transforms/input.py`. + This argument is deprecated in favor of model_configs. covar_module_class: Covariance module class. This gets initialized after parsing the ``covar_module_options`` in ``covar_module_argparse``, and gets passed to the model constructor as ``covar_module``. - covar_module_options: Covariance module kwargs. + This argument is deprecated in favor of model_configs. + covar_module_options: Covariance module kwargs. This argument is deprecated + in favor of model_configs. likelihood: ``Likelihood`` class. This gets initialized with ``likelihood_options`` and gets passed to the model constructor. - likelihood_options: Likelihood options. + This argument is deprecated in favor of model_configs. + likelihood_options: Likelihood options. This argument is deprecated in favor + of model_configs. + model_configs: List of model configs. Each model config is a specification of + a model. These should be used in favor of the above deprecated arguments. + metric_to_model_configs: Dictionary mapping metric names to a list of model + configs for that metric. allow_batched_models: Set to true to fit the models in a batch if supported. Set to false to fit individual models to each metric in a loop. 
""" def __init__( self, + surrogate_spec: SurrogateSpec | None = None, botorch_model_class: type[Model] | None = None, model_options: dict[str, Any] | None = None, mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood, mll_options: dict[str, Any] | None = None, - outcome_transform_classes: Sequence[type[OutcomeTransform]] | None = None, + outcome_transform_classes: list[type[OutcomeTransform]] | None = None, outcome_transform_options: dict[str, dict[str, Any]] | None = None, - input_transform_classes: Sequence[type[InputTransform]] | None = None, + input_transform_classes: list[type[InputTransform]] | None = None, input_transform_options: dict[str, dict[str, Any]] | None = None, covar_module_class: type[Kernel] | None = None, covar_module_options: dict[str, Any] | None = None, @@ -341,21 +527,49 @@ def __init__( likelihood_options: dict[str, Any] | None = None, allow_batched_models: bool = True, ) -> None: - self.botorch_model_class = botorch_model_class - # Copying model options to avoid mutating the original dict. - # We later update it with any additional kwargs passed into `BoTorchModel.fit`. - self.model_options: dict[str, Any] = (model_options or {}).copy() - self.mll_class = mll_class - self.mll_options: dict[str, Any] = mll_options or {} - self.outcome_transform_classes = outcome_transform_classes - self.outcome_transform_options: dict[str, Any] = outcome_transform_options or {} - self.input_transform_classes = input_transform_classes - self.input_transform_options: dict[str, Any] = input_transform_options or {} - self.covar_module_class = covar_module_class - self.covar_module_options: dict[str, Any] = covar_module_options or {} - self.likelihood_class = likelihood_class - self.likelihood_options: dict[str, Any] = likelihood_options or {} - self.allow_batched_models = allow_batched_models + warnings_raised = _raise_deprecation_warning( + is_surrogate=True, + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + # check if surrogate_spec is provided + if surrogate_spec is None: + # create surrogate spec from deprecated arguments + model_config = get_model_config_from_deprecated_args( + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + surrogate_spec = SurrogateSpec( + model_configs=[model_config], allow_batched_models=allow_batched_models + ) + + elif warnings_raised: + raise UserInputError( + "model_configs and deprecated arguments were both specified. " + "Please use model_configs and remove deprecated arguments." + ) + + self.surrogate_spec: SurrogateSpec = surrogate_spec # Store the last dataset used to fit the model for a given metric(s). 
# If the new dataset is identical, we will skip model fitting for that metric. # The keys are `tuple(dataset.outcome_names)`. @@ -375,13 +589,7 @@ def __init__( self._model: Model | None = None def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__}" - f" botorch_model_class={self.botorch_model_class} " - f"mll_class={self.mll_class} " - f"outcome_transform_classes={self.outcome_transform_classes} " - f"input_transform_classes={self.input_transform_classes} " - ) + return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>" @property def model(self) -> Model: @@ -404,9 +612,7 @@ def Xs(self) -> list[Tensor]: training_data = self.training_data Xs = [] for dataset in training_data: - if self.botorch_model_class == PairwiseGP and isinstance( - dataset, RankingDataset - ): + if isinstance(dataset, RankingDataset): # directly accessing the d-dim X tensor values # instead of the augmented 2*d-dim dataset.X from RankingDataset Xi = checked_cast(SliceContainer, dataset._X).values @@ -431,7 +637,8 @@ def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - botorch_model_class: type[Model], + model_config: ModelConfig, + default_botorch_model_class: type[Model], state_dict: OrderedDict[str, Tensor] | None, refit: bool, ) -> Model: @@ -446,19 +653,24 @@ def _construct_model( multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. - botorch_model_class: ``Model`` class to be used as the underlying - BoTorch model. + model_config: The model_config. + default_botorch_model_class: The default ``Model`` class to be used as the + underlying BoTorch model, if the model_config does not specify one. state_dict: Optional state dict to load. This should be subsetted for the current submodel being constructed. refit: Whether to re-optimize model parameters. """ outcome_names = tuple(dataset.outcome_names) + botorch_model_class = ( + model_config.botorch_model_class or default_botorch_model_class + ) if self._should_reuse_last_model( dataset=dataset, botorch_model_class=botorch_model_class ): return self._submodels[outcome_names] formatted_model_inputs = submodel_input_constructor( botorch_model_class, # Do not pass as kwarg since this is used to dispatch. + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=self, @@ -469,7 +681,9 @@ def _construct_model( model.load_state_dict(state_dict) if state_dict is None or refit: fit_botorch_model( - model=model, mll_class=self.mll_class, mll_options=self.mll_options + model=model, + mll_class=model_config.mll_class, + mll_options=model_config.mll_options, ) self._submodels[outcome_names] = model self._last_datasets[outcome_names] = dataset @@ -539,14 +753,23 @@ def fit( # To determine whether to use ModelList under the hood, we need to check for # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. 
- botorch_model_class = self.botorch_model_class or choose_model_class( - datasets=datasets, search_space_digest=search_space_digest - ) - + if ( + len(self.surrogate_spec.model_configs) == 1 + and self.surrogate_spec.model_configs[0].botorch_model_class is None + ): + default_botorch_model_class = choose_model_class( + datasets=datasets, search_space_digest=search_space_digest + ) + else: + default_botorch_model_class = self.surrogate_spec.model_configs[ + 0 + ].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, - botorch_model_class=botorch_model_class, - allow_batched_models=self.allow_batched_models, + botorch_model_class=not_none(default_botorch_model_class), + model_configs=self.surrogate_spec.model_configs, + allow_batched_models=self.surrogate_spec.allow_batched_models, + metric_to_model_configs=self.surrogate_spec.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: @@ -564,10 +787,30 @@ def fit( ) else: submodel_state_dict = state_dict + model_config = None + if len(self.surrogate_spec.metric_to_model_configs) > 0: + # if metric_to_model_configs is not empty, then + # we are using a model list and each dataset + # should have only one outcome. + if len(dataset.outcome_names) > 1: + raise ValueError( + "Each dataset should have only one outcome when " + "metric_to_model_configs is specified." + ) + model_config_list = self.surrogate_spec.metric_to_model_configs.get( + dataset.outcome_names[0] + ) + + # TODO: add support for automated model selection + if model_config_list is not None: + model_config = model_config_list[0] + if model_config is None: + model_config = self.surrogate_spec.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, - botorch_model_class=botorch_model_class, + model_config=model_config, + default_botorch_model_class=not_none(default_botorch_model_class), state_dict=submodel_state_dict, refit=refit, ) @@ -727,24 +970,10 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: """Serialize attributes of this surrogate, to be passed back to it as kwargs on reinstantiation. """ - return { - "botorch_model_class": self.botorch_model_class, - "model_options": self.model_options, - "mll_class": self.mll_class, - "mll_options": self.mll_options, - "outcome_transform_classes": self.outcome_transform_classes, - "outcome_transform_options": self.outcome_transform_options, - "input_transform_classes": self.input_transform_classes, - "input_transform_options": self.input_transform_options, - "covar_module_class": self.covar_module_class, - "covar_module_options": self.covar_module_options, - "likelihood_class": self.likelihood_class, - "likelihood_options": self.likelihood_options, - "allow_batched_models": self.allow_batched_models, - } + return {"surrogate_spec": self.surrogate_spec} def _extract_construct_input_transform_args( - self, search_space_digest: SearchSpaceDigest + self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest ) -> tuple[Sequence[type[InputTransform]] | None, dict[str, dict[str, Any]]]: """ Extracts input transform classes and input transform options that will @@ -777,19 +1006,19 @@ def _extract_construct_input_transform_args( InputPerturbation ] - if self.input_transform_classes is not None: + if model_config.input_transform_classes is not None: # TODO: Support mixing with user supplied transforms. raise NotImplementedError( "User supplied input transforms are not supported " "in robust optimization." 
) else: - submodel_input_transform_classes = self.input_transform_classes - submodel_input_transform_options = self.input_transform_options + submodel_input_transform_classes = model_config.input_transform_classes + submodel_input_transform_options = model_config.input_transform_options return ( submodel_input_transform_classes, - submodel_input_transform_options, + none_throws(submodel_input_transform_options), ) @property @@ -811,6 +1040,7 @@ def outcomes(self, value: list[str]) -> None: @submodel_input_constructor.register(Model) def _submodel_input_constructor_base( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -819,6 +1049,7 @@ def _submodel_input_constructor_base( Args: botorch_model_class: The BoTorch model class to instantiate. + model_config: The model config. dataset: The training data for the model. search_space_digest: Search space digest used to set up model arguments. surrogate: A reference to the surrogate that created the model. @@ -836,12 +1067,12 @@ def _submodel_input_constructor_base( input_transform_classes, input_transform_options, ) = surrogate._extract_construct_input_transform_args( - search_space_digest=search_space_digest + model_config=model_config, search_space_digest=search_space_digest ) formatted_model_inputs = botorch_model_class.construct_inputs( training_data=dataset, - **surrogate.model_options, + **model_config.model_options, **model_kwargs_from_ss, ) @@ -851,14 +1082,18 @@ def _submodel_input_constructor_base( inputs=[ ( "covar_module", - surrogate.covar_module_class, - surrogate.covar_module_options, + model_config.covar_module_class, + model_config.covar_module_options, + ), + ( + "likelihood", + model_config.likelihood_class, + model_config.likelihood_options, ), - ("likelihood", surrogate.likelihood_class, surrogate.likelihood_options), ( "outcome_transform", - surrogate.outcome_transform_classes, - surrogate.outcome_transform_options, + model_config.outcome_transform_classes, + model_config.outcome_transform_options, ), ( "input_transform", @@ -880,6 +1115,7 @@ def _submodel_input_constructor_base( @submodel_input_constructor.register(MultiTaskGP) def _submodel_input_constructor_mtgp( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -888,6 +1124,7 @@ def _submodel_input_constructor_mtgp( raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=surrogate, diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 26a273ddd4e..bd207a08973 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -9,6 +9,7 @@ import warnings from collections import OrderedDict from collections.abc import Sequence +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -34,8 +35,13 @@ from botorch.models.model import Model, ModelList from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset from 
botorch.utils.transforms import is_fully_bayesian
+from gpytorch.kernels.kernel import Kernel
+from gpytorch.likelihoods import Likelihood
+from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
 from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
 from pyre_extensions import none_throws
 from torch import Tensor
@@ -44,20 +50,115 @@ logger: Logger = get_logger(__name__)
+@dataclass
+class ModelConfig:
+    """Configuration for the BoTorch Model used in Surrogate.
+
+    Args:
+        botorch_model_class: ``Model`` class to be used as the underlying
+            BoTorch model. If None is provided, a model class (either one for
+            all outcomes or a ModelList with separate models for each outcome)
+            will be selected automatically based off the datasets at
+            `construct` time.
+        model_options: Dictionary of options / kwargs for the BoTorch
+            ``Model`` constructed during ``Surrogate.fit``.
+            Note that the corresponding attribute will later be updated to
+            include any additional kwargs passed into ``BoTorchModel.fit``.
+        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
+        mll_options: Dictionary of options / kwargs for the MLL.
+        outcome_transform_classes: List of BoTorch outcome transform classes.
+            Passed down to the BoTorch ``Model``. Multiple outcome transforms
+            can be chained together using ``ChainedOutcomeTransform``.
+        outcome_transform_options: Outcome transform classes kwargs. The keys are
+            class string names and the values are dictionaries of outcome transform
+            kwargs. For example,
+            `
+            outcome_transform_classes = [Standardize]
+            outcome_transform_options = {
+                "Standardize": {"m": 1},
+            }
+            `
+            For more options see `botorch/models/transforms/outcome.py`.
+        input_transform_classes: List of BoTorch input transform classes.
+            Passed down to the BoTorch ``Model``. Multiple input transforms
+            will be chained together using ``ChainedInputTransform``.
+        input_transform_options: Input transform classes kwargs. The keys are
+            class string names and the values are dictionaries of input transform
+            kwargs. For example,
+            `
+            input_transform_classes = [Normalize, Round]
+            input_transform_options = {
+                "Normalize": {"d": 3},
+                "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
+            }
+            `
+            For more input options see `botorch/models/transforms/input.py`.
+        covar_module_class: Covariance module class. This gets initialized after
+            parsing the ``covar_module_options`` in ``covar_module_argparse``,
+            and gets passed to the model constructor as ``covar_module``.
+        covar_module_options: Covariance module kwargs.
+        likelihood_class: ``Likelihood`` class. This gets initialized with
+            ``likelihood_options`` and gets passed to the model constructor.
+        likelihood_options: Likelihood options.
+ """ + + botorch_model_class: type[Model] | None = None + model_options: dict[str, Any] = field(default_factory=dict) + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_options: dict[str, Any] = field(default_factory=dict) + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = field( + default_factory=dict + ) + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict) + covar_module_class: type[Kernel] | None = None + covar_module_options: dict[str, Any] = field(default_factory=dict) + likelihood_class: type[Likelihood] | None = None + likelihood_options: dict[str, Any] = field(default_factory=dict) + + def use_model_list( datasets: Sequence[SupervisedDataset], botorch_model_class: type[Model], + model_configs: list[ModelConfig] | None = None, + metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> bool: - if issubclass(botorch_model_class, MultiTaskGP): - # We currently always wrap multi-task models into `ModelListGP`. + model_configs = model_configs or [] + metric_to_model_configs = metric_to_model_configs or {} + if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1: + # There is only one outcome, so we can use a single model. + return False + elif ( + len(model_configs) > 1 + or len(metric_to_model_configs) > 0 + or any(len(model_config) for model_config in metric_to_model_configs.values()) + ): + # There are multiple outcomes and outcomes might be modeled with different + # models return True - elif issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): + # Otherwise, the same model class is used for all outcomes. + # Determine what the model class is. + if len(model_configs) > 0: + botorch_model_class = ( + model_configs[0].botorch_model_class or botorch_model_class + ) + if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): # SAAS models do not support multiple outcomes. # Use model list if there are multiple outcomes. return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 + elif issubclass(botorch_model_class, MultiTaskGP): + # We wrap multi-task models into `ModelListGP` when there are + # multiple outcomes. + return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 elif len(datasets) == 1: - # Just one outcome, can use single model. + # This method is called before multiple datasets are merged into + # one if using a batched model. If there is one dataset here, + # there should be a reason that a single model should be used: + # e.g. a contextual model, where we want to jointly model the metric + # each context (and context-level metrics are different outcomes). 
return False elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all( torch.equal(datasets[0].X, ds.X) for ds in datasets[1:] diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index c8cf1ddba7a..797719a0f5c 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -21,12 +21,12 @@ from ax.models.torch.botorch_modular.model import ( BoTorchModel, choose_botorch_acqf_class, - SurrogateSpec, ) -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ( choose_model_class, construct_acquisition_and_optimizer_options, + ModelConfig, ) from ax.models.torch.utils import _filter_X_observed from ax.models.torch_base import TorchOptConfig @@ -327,7 +327,21 @@ def test_fit(self, mock_fit: Mock) -> None: mock_fit.assert_called_with( dataset=self.block_design_training_data[0], search_space_digest=self.mf_search_space_digest, - botorch_model_class=SingleTaskMultiFidelityGP, + model_config=ModelConfig( + botorch_model_class=None, + model_options={}, + mll_class=ExactMarginalLogLikelihood, + mll_options={}, + input_transform_classes=None, + input_transform_options={}, + outcome_transform_classes=None, + outcome_transform_options={}, + covar_module_class=None, + covar_module_options={}, + likelihood_class=None, + likelihood_options={}, + ), + default_botorch_model_class=SingleTaskMultiFidelityGP, state_dict=None, refit=True, ) @@ -704,30 +718,17 @@ def test_evaluate_acquisition_function( def test_surrogate_model_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel( - surrogate_spec=SurrogateSpec( - botorch_model_kwargs={"some_option": "some_value"} - ) + surrogate_spec = SurrogateSpec( + botorch_model_kwargs={"some_option": "some_value"} ) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={"some_option": "some_value"}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - allow_batched_models=True, + surrogate_spec=surrogate_spec, ) @mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) @@ -736,26 +737,15 @@ def test_surrogate_model_options_propagation( def test_surrogate_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel(surrogate_spec=SurrogateSpec(allow_batched_models=False)) + surrogate_spec = SurrogateSpec(allow_batched_models=False) + model = BoTorchModel(surrogate_spec=surrogate_spec) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) mock_init.assert_called_with( - botorch_model_class=None, - model_options={}, - mll_class=ExactMarginalLogLikelihood, - mll_options={}, - covar_module_class=None, - covar_module_options=None, - likelihood_class=None, - likelihood_options=None, - input_transform_classes=None, - input_transform_options=None, - outcome_transform_classes=None, - outcome_transform_options=None, - 
allow_batched_models=False, + surrogate_spec=surrogate_spec, ) @mock.patch( diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 2d6b12b7ef0..eadce0951bf 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -18,8 +18,17 @@ from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate -from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model +from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel +from ax.models.torch.botorch_modular.surrogate import ( + _extract_model_kwargs, + Surrogate, + SurrogateSpec, +) +from ax.models.torch.botorch_modular.utils import ( + choose_model_class, + fit_botorch_model, + ModelConfig, +) from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast @@ -34,10 +43,10 @@ from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood from botorch.models.transforms.input import InputPerturbation, Normalize -from botorch.models.transforms.outcome import Standardize +from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import SupervisedDataset from gpytorch.constraints import GreaterThan, Interval -from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood from pyre_extensions import assert_is_instance, none_throws @@ -197,7 +206,7 @@ def _get_surrogate( mll_options = None if use_outcome_transform: - outcome_transform_classes = [Standardize] + outcome_transform_classes: list[type[OutcomeTransform]] = [Standardize] outcome_transform_options = {"Standardize": {"m": 1}} else: outcome_transform_classes = None @@ -216,9 +225,16 @@ def _get_surrogate( def test_init(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - self.assertEqual(surrogate.botorch_model_class, botorch_model_class) - self.assertEqual(surrogate.mll_class, self.mll_class) - self.assertTrue(surrogate.allow_batched_models) # True by default + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].botorch_model_class, + botorch_model_class, + ) + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].mll_class, self.mll_class + ) + self.assertTrue( + surrogate.surrogate_spec.allow_batched_models + ) # True by default def test_clone_reset(self) -> None: surrogate = self._get_surrogate(botorch_model_class=SingleTaskGP)[0] @@ -432,7 +448,8 @@ def test_construct_model(self) -> None: Surrogate()._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=Model, + model_config=ModelConfig(), + default_botorch_model_class=Model, state_dict=None, refit=True, ) @@ -446,7 +463,8 @@ def test_construct_model(self) -> None: model = surrogate._construct_model( dataset=self.training_data[0], 
search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.surrogate_spec.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -471,7 +489,8 @@ def test_construct_model(self) -> None: new_model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.surrogate_spec.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -485,7 +504,8 @@ def test_construct_model(self) -> None: surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=SingleTaskGPWithDifferentConstructor, + model_config=ModelConfig(), + default_botorch_model_class=SingleTaskGPWithDifferentConstructor, state_dict=None, refit=True, ) @@ -497,14 +517,22 @@ def test_construct_model(self) -> None: ) @mock_botorch_optimize - def test_construct_custom_model(self) -> None: + def test_construct_custom_model(self, use_model_config: bool = False) -> None: # Test error for unsupported covar_module and likelihood. - surrogate = Surrogate( - botorch_model_class=SingleTaskGPWithDifferentConstructor, - mll_class=self.mll_class, - covar_module_class=RBFKernel, - likelihood_class=FixedNoiseGaussianLikelihood, - ) + model_config_kwargs: dict[str, Any] = { + "botorch_model_class": SingleTaskGPWithDifferentConstructor, + "mll_class": self.mll_class, + "covar_module_class": RBFKernel, + "likelihood_class": FixedNoiseGaussianLikelihood, + } + if use_model_config: + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) + else: + surrogate = Surrogate(**model_config_kwargs) with self.assertRaisesRegex(UserInputError, "does not support"): surrogate.fit( self.training_data, @@ -512,14 +540,22 @@ def test_construct_custom_model(self) -> None: ) # Pass custom options to a SingleTaskGP and make sure they are used noise_constraint = Interval(1e-6, 1e-1) - surrogate = Surrogate( - botorch_model_class=SingleTaskGP, - mll_class=LeaveOneOutPseudoLikelihood, - covar_module_class=RBFKernel, - covar_module_options={"ard_num_dims": 3}, - likelihood_class=GaussianLikelihood, - likelihood_options={"noise_constraint": noise_constraint}, - ) + model_config_kwargs = { + "botorch_model_class": SingleTaskGP, + "mll_class": LeaveOneOutPseudoLikelihood, + "covar_module_class": RBFKernel, + "covar_module_options": {"ard_num_dims": 3}, + "likelihood_class": GaussianLikelihood, + "likelihood_options": {"noise_constraint": noise_constraint}, + } + if use_model_config: + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(**model_config_kwargs)] + ) + ) + else: + surrogate = Surrogate(**model_config_kwargs) surrogate.fit( self.training_data, search_space_digest=self.search_space_digest, @@ -532,10 +568,68 @@ def test_construct_custom_model(self) -> None: model.likelihood.noise_covar.raw_noise_constraint.__dict__, noise_constraint.__dict__, ) - self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].mll_class, + LeaveOneOutPseudoLikelihood, + ) self.assertEqual(type(model.covar_module), RBFKernel) self.assertEqual(model.covar_module.ard_num_dims, 3) + def test_construct_custom_model_with_config(self) -> None: + self.test_construct_custom_model(use_model_config=True) + + def 
test_construct_model_with_metric_to_model_configs(self) -> None: + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + metric_to_model_configs={ + "metric": [ModelConfig()], + "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], + }, + model_configs=[ModelConfig(covar_module_class=LinearKernel)], + ) + ) + training_data = self.training_data + [ + SupervisedDataset( + X=self.Xs[0], + # Note: using 1d Y does not match the 2d TorchOptConfig + Y=self.Ys[0], + feature_names=self.feature_names, + outcome_names=[f"metric{i}"], + ) + for i in range(2, 5) + ] + surrogate.fit( + datasets=training_data, search_space_digest=self.search_space_digest + ) + # test model follows metric_to_model_configs for + # first two metrics + self.assertIsInstance(surrogate.model, ModelListGP) + submodels = surrogate.model.models + self.assertEqual(len(submodels), 4) + for m in submodels: + self.assertIsInstance(m, SingleTaskGP) + self.assertIsInstance(surrogate.model.models[1].covar_module, ScaleKernel) + self.assertIsInstance( + surrogate.model.models[1].covar_module.base_kernel, MaternKernel + ) + self.assertIsInstance(surrogate.model.models[0].covar_module, RBFKernel) + # test model use model_configs for the third metric + self.assertIsInstance(surrogate.model.models[2].covar_module, LinearKernel) + + def test_multiple_model_configs_error(self) -> None: + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + SurrogateSpec( + model_configs=[ModelConfig(), ModelConfig()], + ) + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + SurrogateSpec( + metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]}, + ) + @mock_botorch_optimize @patch(f"{SURROGATE_PATH}.predict_from_model") def test_predict(self, mock_predict: Mock) -> None: @@ -646,9 +740,7 @@ def test_best_out_of_sample_point(self) -> None: def test_serialize_attributes_as_kwargs(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - expected = { - k: v for k, v in surrogate.__dict__.items() if not k.startswith("_") - } + expected = {"surrogate_spec": surrogate.surrogate_spec} self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) @mock_botorch_optimize @@ -677,7 +769,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.training_data, @@ -830,11 +922,12 @@ def setUp(self) -> None: ) def test_init(self) -> None: + model_config = self.surrogate.surrogate_spec.model_configs[0] self.assertEqual( - [self.surrogate.botorch_model_class] * 2, + [model_config.botorch_model_class] * 2, [*self.botorch_submodel_class_per_outcome.values()], ) - self.assertEqual(self.surrogate.mll_class, self.mll_class) + self.assertEqual(model_config.mll_class, self.mll_class) with self.assertRaisesRegex( ValueError, "BoTorch `Model` has not yet been constructed" ): @@ -849,7 +942,9 @@ def test_init(self) -> None: def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock ) -> None: - self.surrogate.model_options.update({"output_tasks": [2]}) + self.surrogate.surrogate_spec.model_configs[0].model_options.update( + 
{"output_tasks": [2]} + ) for fixed_noise in (False, True): mock_fit.reset_mock() mock_MTGP_construct_inputs.reset_mock() @@ -940,10 +1035,14 @@ def test_fit( self.assertIsNone(surrogate._model) # Should instantiate mll and `fit_gpytorch_mll` when `state_dict` # is `None`. - # pyre-ignore[6]: Incompatible parameter type: In call - # `issubclass`, for 1st positional argument, expected - # `Type[typing.Any]` but got `Optional[Type[Model]]`. - is_mtgp = issubclass(surrogate.botorch_model_class, MultiTaskGP) + + is_mtgp = issubclass( + # pyre-ignore[6]: Incompatible parameter type: In call + # `issubclass`, for 1st positional argument, expected + # `Type[typing.Any]` but got `Optional[Type[Model]]`. + surrogate.surrogate_spec.model_configs[0].botorch_model_class, + MultiTaskGP, + ) search_space_digest = ( self.multi_task_search_space_digest if is_mtgp @@ -1097,25 +1196,6 @@ def test_with_botorch_transforms(self) -> None: ) ) - def test_serialize_attributes_as_kwargs(self) -> None: - # TODO[mpolson64] Reimplement this when serialization has been sorted out - pass - # expected = self.surrogate.__dict__ - # # The two attributes below don't need to be saved as part of state, - # # so we remove them from the expected dict. - # for attr_name in ( - # "botorch_model_class", - # "model_options", - # "covar_module_class", - # "covar_module_options", - # "likelihood_class", - # "likelihood_options", - # "outcome_transform", - # "input_transform", - # ): - # expected.pop(attr_name) - # self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected) - @mock_botorch_optimize def test_construct_custom_model(self) -> None: noise_constraint = Interval(1e-4, 10.0) @@ -1144,7 +1224,10 @@ def test_construct_custom_model(self) -> None: ) models = checked_cast(ModelListGP, surrogate._model).models self.assertEqual(len(models), 2) - self.assertEqual(surrogate.mll_class, ExactMarginalLogLikelihood) + self.assertEqual( + surrogate.surrogate_spec.model_configs[0].mll_class, + ExactMarginalLogLikelihood, + ) # Make sure we properly copied the transforms. 
self.assertNotEqual( id(models[0].input_transform), id(models[1].input_transform) @@ -1197,7 +1280,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.surrogate_spec.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.supervised_training_data, diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 19d263f5fa1..19e05552b68 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -306,7 +306,7 @@ def test_use_model_list(self) -> None: botorch_model_class=SingleTaskGP, ) ) - self.assertTrue( + self.assertFalse( use_model_list( datasets=self.supervised_datasets, botorch_model_class=MultiTaskGP ) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index d39e46801dd..140ad1b33bd 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -44,8 +44,8 @@ TransitionCriterion, TrialBasedCriterion, ) -from ax.models.torch.botorch_modular.model import SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( batch_trial_from_json, botorch_component_from_json, @@ -229,7 +229,7 @@ def object_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - elif _class in (SurrogateSpec, Surrogate): + elif _class in (SurrogateSpec, Surrogate, ModelConfig): if "input_transform" in object_json: ( input_transform_classes_json, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index a798c30b1fe..3f0aad31e66 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -96,8 +96,9 @@ TransitionCriterion, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType @@ -330,6 +331,7 @@ "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, "Models": Models, "ModelRegistryBase": ModelRegistryBase, + "ModelConfig": ModelConfig, "ModelSpec": ModelSpec, "MultiObjective": MultiObjective, "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 24c8a634e57..a2e17f95acb 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -600,11 +600,11 @@ def test_encode_decode_surrogate_spec(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) - org_as_dict = dataclasses.asdict(org_object) - converted_as_dict = dataclasses.asdict(converted_object) + org_as_dict = dataclasses.asdict(org_object)["model_configs"][0] + 
converted_as_dict = dataclasses.asdict(converted_object)["model_configs"][0] # Covar module kwargs will fail comparison. Manually compare. - org_covar_kwargs = org_as_dict.pop("covar_module_kwargs") - converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs") + org_covar_kwargs = org_as_dict.pop("covar_module_options") + converted_covar_kwargs = converted_as_dict.pop("covar_module_options") self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys()) for k in org_covar_kwargs: org_ = org_covar_kwargs[k] diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 6aaf96ab9de..8e130319fe2 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -28,7 +28,7 @@ from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.service.scheduler import SchedulerOptions from ax.utils.common.constants import Keys from ax.utils.testing.core_stubs import ( @@ -189,7 +189,11 @@ def get_sobol_gpei_benchmark_method() -> BenchmarkMethod: model=Models.BOTORCH_MODULAR, num_trials=-1, model_kwargs={ - "surrogate": Surrogate(SingleTaskGP), + "surrogate": Surrogate( + surrogate_spec=SurrogateSpec( + botorch_model_class=SingleTaskGP + ) + ), # TODO: tests should better reflect defaults and not # re-implement this logic. "botorch_acqf_class": qNoisyExpectedImprovement, diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 7d90a339df1..007375df242 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -99,9 +99,9 @@ ) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec +from ax.models.torch.botorch_modular.model import BoTorchModel from ax.models.torch.botorch_modular.sebo import SEBOAcquisition -from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index 275530f2f3c..dc8f72147ec 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -1,13 +1,44 @@ { + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5, "cells": [ { "cell_type": "code", - "execution_count": null, - "id": "about-preview", "metadata": { - "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + "metadata": { + "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + }, + "id": "about-preview", + "originalKey": "f4e8ae18-2aa3-4943-a15a-29851889445c", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 
1730916291451, + "executionStopTime": 1730916298337, + "serverExecutionDuration": 4531.2523420434, + "collapsed": false, + "requestMsgId": "f4e8ae18-2aa3-4943-a15a-29851889445c", + "customOutput": null }, - "outputs": [], "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", "\n", @@ -18,7 +49,8 @@ "\n", "# Ax wrappers for BoTorch components\n", "from ax.models.torch.botorch_modular.model import BoTorchModel\n", - "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", + "from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec\n", + "from ax.models.torch.botorch_modular.utils import ModelConfig\n", "\n", "# Experiment examination utilities\n", "from ax.service.utils.report_utils import exp_to_df\n", @@ -39,13 +71,37 @@ "# BoTorch components\n", "from botorch.models.model import Model\n", "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1106 100452.333 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1106 100452.334 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" + ] + } ] }, { "cell_type": "markdown", - "id": "northern-affairs", "metadata": { - "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + "metadata": { + "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + }, + "id": "northern-affairs", + "originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "# Setup and Usage of BoTorch Models in Ax\n", @@ -70,9 +126,16 @@ }, { "cell_type": "markdown", - "id": "pending-support", "metadata": { - "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + "metadata": { + "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + }, + "id": "pending-support", + "originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 1. Quick-start example\n", @@ -82,41 +145,89 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "parental-sending", "metadata": { - "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + "metadata": { + "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + }, + "id": "parental-sending", + "originalKey": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916294801, + "executionStopTime": 1730916298389, + "serverExecutionDuration": 22.605526028201, + "collapsed": false, + "requestMsgId": "20f25ded-5aae-47ee-955e-a2d5a2a1fe09", + "customOutput": null }, - "outputs": [], "source": [ "experiment = get_branin_experiment(with_trial=True)\n", "data = get_branin_data(trials=[experiment.trials[0]])" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:04:56] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. 
If you are running a live experiment, please set this flag to False\n" + ] + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "rough-somerset", "metadata": { - "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + "metadata": { + "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + }, + "id": "rough-somerset", + "originalKey": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916295849, + "executionStopTime": 1730916299900, + "serverExecutionDuration": 852.73489891551, + "collapsed": false, + "requestMsgId": "c0806cce-a1d3-41b8-96fc-678aa3c9dd92" }, - "outputs": [], "source": [ "# `Models` automatically selects a model + model bridge combination.\n", "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", - " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " ), # Optional, will use default if unspecified\n", " botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n", ")" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:04:57] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + } ] }, { "cell_type": "markdown", - "id": "hairy-wiring", "metadata": { - "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + "metadata": { + "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + }, + "id": "hairy-wiring", + "originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." 
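The cells that follow only exercise `gen`. A hedged sketch of the other two calls mentioned above, assuming the `model_bridge_with_GPEI` object built earlier in this tutorial and the Branin parameters `x1`/`x2`:

from ax.core.observation import ObservationFeatures

obs_feats = [ObservationFeatures(parameters={"x1": 1.0, "x2": 2.0})]

# Posterior prediction at a point: per-metric means and covariances.
means, covariances = model_bridge_with_GPEI.predict(obs_feats)

# Acquisition-function value at the same point (one value per input).
acqf_values = model_bridge_with_GPEI.evaluate_acquisition_function(
    observation_features=obs_feats
)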
@@ -124,22 +235,49 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "consecutive-summary", "metadata": { - "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + "metadata": { + "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + }, + "id": "consecutive-summary", + "originalKey": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916299852, + "executionStopTime": 1730916300305, + "serverExecutionDuration": 233.20194100961, + "collapsed": false, + "requestMsgId": "f64e9d2e-bfd4-47da-8292-dbe7e70cbe1f" }, - "outputs": [], "source": [ "generator_run = model_bridge_with_GPEI.gen(n=1)\n", "generator_run.arms[0]" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "Arm(parameters={'x1': 10.0, 'x2': 15.0})" + }, + "metadata": {}, + "execution_count": 4 + } ] }, { "cell_type": "markdown", - "id": "diverse-richards", "metadata": { - "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + "metadata": { + "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + }, + "id": "diverse-richards", + "originalKey": "804bac30-db07-4444-98a2-7a5f05007495", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "-----\n", @@ -151,21 +289,35 @@ }, { "cell_type": "markdown", - "id": "grand-committee", "metadata": { - "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + "metadata": { + "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + }, + "id": "grand-committee", + "originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2. BoTorchModel = Surrogate + Acquisition\n", "\n", - "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." + "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." ] }, { "cell_type": "markdown", - "id": "thousand-blanket", "metadata": { - "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + "metadata": { + "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + }, + "id": "thousand-blanket", + "originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2A. 
Example that uses defaults and requires no options\n", @@ -175,12 +327,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "changing-xerox", "metadata": { - "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + "metadata": { + "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + }, + "id": "changing-xerox", + "originalKey": "fa86552a-0b80-4040-a0c4-61a0de37bdc1", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916302730, + "executionStopTime": 1730916304031, + "serverExecutionDuration": 1.7747740494087, + "collapsed": false, + "requestMsgId": "fa86552a-0b80-4040-a0c4-61a0de37bdc1" }, - "outputs": [], "source": [ "# The surrogate is not specified, so it will be auto-selected\n", "# during `model.fit`.\n", @@ -188,17 +349,30 @@ "\n", "# The acquisition class is not specified, so it will be\n", "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", - "GPEI_model = BoTorchModel(surrogate=Surrogate(SingleTaskGP))\n", + "GPEI_model = BoTorchModel(\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " )\n", + ")\n", "\n", "# Both the surrogate and acquisition class will be auto-selected.\n", "GPEI_model = BoTorchModel()" - ] + ], + "execution_count": 5, + "outputs": [] }, { "cell_type": "markdown", - "id": "lovely-mechanics", "metadata": { - "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + "metadata": { + "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + }, + "id": "lovely-mechanics", + "originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 2B. Example with all the options\n", @@ -207,23 +381,36 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "twenty-greek", "metadata": { - "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + "metadata": { + "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + }, + "id": "twenty-greek", + "originalKey": "8d824e37-b087-4bab-9b16-4354e9509df7", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916305930, + "executionStopTime": 1730916306168, + "serverExecutionDuration": 2.6916969800368, + "collapsed": false, + "requestMsgId": "8d824e37-b087-4bab-9b16-4354e9509df7" }, - "outputs": [], "source": [ "model = BoTorchModel(\n", " # Optional `Surrogate` specification to use instead of default\n", - " surrogate=Surrogate(\n", - " # BoTorch `Model` type\n", - " botorch_model_class=SingleTaskGP,\n", - " # Optional, MLL class with which to optimize model parameters\n", - " mll_class=ExactMarginalLogLikelihood,\n", - " # Optional, dictionary of keyword arguments to underlying\n", - " # BoTorch `Model` constructor\n", - " model_options={},\n", + " surrogate_spec=SurrogateSpec(\n", + " model_configs=[\n", + " ModelConfig(\n", + " # BoTorch `Model` type\n", + " botorch_model_class=SingleTaskGP,\n", + " # Optional, MLL class with which to optimize model parameters\n", + " mll_class=ExactMarginalLogLikelihood,\n", + " # Optional, dictionary of keyword arguments to underlying\n", + " # BoTorch `Model` constructor\n", + " model_options={},\n", + " )\n", + " ]\n", " ),\n", " # Optional BoTorch `AcquisitionFunction` to use instead of default\n", " botorch_acqf_class=qLogExpectedImprovement,\n", @@ -238,13 +425,22 @@ " refit_on_cv=False,\n", " warm_start_refit=True,\n", ")" - ] + ], + "execution_count": 6, + 
"outputs": [] }, { "cell_type": "markdown", - "id": "fourth-material", "metadata": { - "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + "metadata": { + "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + }, + "id": "fourth-material", + "originalKey": "7140bb19-09b4-4abe-951d-53902ae07833", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", @@ -258,9 +454,16 @@ }, { "cell_type": "markdown", - "id": "violent-course", "metadata": { - "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + "metadata": { + "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + }, + "id": "violent-course", + "originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" @@ -268,12 +471,19 @@ }, { "cell_type": "markdown", - "id": "unlike-football", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", - "showInput": false + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "showInput": false + }, + "id": "unlike-football", + "originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3a. Making a `Surrogate` from BoTorch `Model`:\n", @@ -281,19 +491,28 @@ "\n", "If your `Model` is not a `ModelListGP`, the steps to set it up as a `Surrogate` are:\n", "1. Implement a [`construct_inputs` class method](https://github.com/pytorch/botorch/blob/main/botorch/models/model.py#L143). The purpose of this method is to produce arguments to a particular model from a standardized set of inputs passed to BoTorch `Model`-s from [`Surrogate.construct`](https://github.com/facebook/Ax/blob/main/ax/models/torch/botorch_modular/surrogate.py#L148) in Ax. It should accept training data in form of a `SupervisedDataset` container and optionally other keyword arguments and produce a dictionary of arguments to `__init__` of the `Model`. See [`SingleTaskMultiFidelityGP.construct_inputs`](https://github.com/pytorch/botorch/blob/5b3172f3daa22f6ea2f6f4d1d0a378a9518dcd8d/botorch/models/gp_regression_fidelity.py#L131) for an example.\n", - "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via `model_options` argument to `Surrogate`." + "2. Pass any additional needed keyword arguments for the `Model` constructor (that cannot be constructed from the training data and other arguments to `construct_inputs`) via the `model_options` argument to `ModelConfig` in `SurrogateSpec`." 
] }, { "cell_type": "code", - "execution_count": null, - "id": "dynamic-university", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + }, + "id": "dynamic-university", + "originalKey": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916308518, + "executionStopTime": 1730916308769, + "serverExecutionDuration": 2.4644429795444, + "collapsed": false, + "requestMsgId": "746fc2a3-0e0e-4ab4-84d9-32434eb1fc34" }, - "outputs": [], "source": [ "from botorch.models.model import Model\n", "from botorch.utils.datasets import SupervisedDataset\n", @@ -317,18 +536,31 @@ " }\n", "\n", "\n", - "surrogate = Surrogate(\n", - " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", - " # Optional dict of additional keyword arguments to `MyModelClass`\n", - " model_options={},\n", + "surrogate_spec = SurrogateSpec(\n", + " model_configs=[\n", + " ModelConfig(\n", + " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", + " # Optional dict of additional keyword arguments to `MyModelClass`\n", + " model_options={},\n", + " )\n", + " ]\n", ")" - ] + ], + "execution_count": 7, + "outputs": [] }, { "cell_type": "markdown", - "id": "otherwise-context", "metadata": { - "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + "metadata": { + "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + }, + "id": "otherwise-context", + "originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." @@ -336,9 +568,16 @@ }, { "cell_type": "markdown", - "id": "northern-invite", "metadata": { - "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + "metadata": { + "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + }, + "id": "northern-invite", + "originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" @@ -346,32 +585,50 @@ }, { "cell_type": "markdown", - "id": "surrounded-denial", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", - "showInput": false + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "showInput": false + }, + "id": "surrounded-denial", + "originalKey": "d4861847-b757-4fcd-9f35-ba258080812c", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Steps to set up any `AcquisitionFunction` in Ax are:\n", - "1. Define an input constructor function. The purpose of this method is to produce arguments to an acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. 
For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", + "1. Define an input constructor function. The purpose of this method is to produce arguments to an acquisition function from a standardized set of inputs passed to BoTorch `AcquisitionFunction`-s from `Acquisition.__init__` in Ax. For example, see [`construct_inputs_qEHVI`](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L477), which creates a fairly complex set of arguments needed by `qExpectedHypervolumeImprovement` –– a popular multi-objective optimization acquisition function offered in Ax and BoTorch. For more examples, see this collection in BoTorch: [botorch/acquisition/input_constructors.py](https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py) \n", " 1. Note that the new input constructor needs to be decorated with `@acqf_input_constructor(AcquisitionFunctionClass)` to register it.\n", - "2. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", - "3. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." + "3. Specify the BoTorch `AcquisitionFunction` class as `botorch_acqf_class` to `BoTorchModel`\n", + "4. (Optional) Pass any additional keyword arguments to acquisition function constructor or to the optimizer function via `acquisition_options` argument to `BoTorchModel`." ] }, { "cell_type": "code", - "execution_count": null, - "id": "interested-search", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + }, + "id": "interested-search", + "originalKey": "f188f40b-64ba-4b0c-b216-f3dea8c7465e", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916310518, + "executionStopTime": 1730916310772, + "serverExecutionDuration": 4.9752569757402, + "collapsed": false, + "requestMsgId": "f188f40b-64ba-4b0c-b216-f3dea8c7465e", + "customOutput": null }, - "outputs": [], "source": [ + "from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse\n", "from botorch.acquisition.acquisition import AcquisitionFunction\n", "from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n", "from botorch.utils.datasets import SupervisedDataset\n", @@ -393,6 +650,7 @@ " pass\n", "\n", "\n", + "\n", "# 2-3. 
Specifying `botorch_acqf_class` and `acquisition_options`\n", "BoTorchModel(\n", " botorch_acqf_class=MyAcquisitionFunctionClass,\n", @@ -405,13 +663,31 @@ " \"optimizer_options\": {\"sequential\": False},\n", " },\n", ")" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "BoTorchModel" + }, + "metadata": {}, + "execution_count": 8 + } ] }, { "cell_type": "markdown", - "id": "metallic-imaging", "metadata": { - "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + "metadata": { + "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + }, + "id": "metallic-imaging", + "originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." @@ -419,9 +695,16 @@ }, { "cell_type": "markdown", - "id": "descending-australian", "metadata": { - "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + "metadata": { + "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + }, + "id": "descending-australian", + "originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 4. Using `Models.BOTORCH_MODULAR` \n", @@ -433,49 +716,123 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "attached-border", "metadata": { - "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + "metadata": { + "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + }, + "id": "attached-border", + "originalKey": "052cf2e4-8de0-4ec3-a3f9-478194b10928", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916311983, + "executionStopTime": 1730916312395, + "serverExecutionDuration": 202.78578903526, + "collapsed": false, + "requestMsgId": "052cf2e4-8de0-4ec3-a3f9-478194b10928" }, - "outputs": [], "source": [ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", ")\n", "model_bridge_with_GPEI.gen(1)" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:05:12] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "GeneratorRun(1 arms, total weight 1.0)" + }, + "metadata": {}, + "execution_count": 9 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "powerful-gamma", "metadata": { - "originalKey": "89930a31-e058-434b-b587-181931e247b6" + "metadata": { + "originalKey": "89930a31-e058-434b-b587-181931e247b6" + }, + "id": "powerful-gamma", + "originalKey": "b7f924fe-f3d9-4211-b402-421f4c90afe5", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916312432, + "executionStopTime": 1730916312657, + "serverExecutionDuration": 3.1334219966084, + "collapsed": false, + "requestMsgId": "b7f924fe-f3d9-4211-b402-421f4c90afe5" }, - "outputs": [], "source": [ "model_bridge_with_GPEI.model.botorch_acqf_class" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": 
"botorch.acquisition.logei.qLogNoisyExpectedImprovement" + }, + "metadata": {}, + "execution_count": 10 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "improved-replication", "metadata": { - "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + "metadata": { + "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + }, + "id": "improved-replication", + "originalKey": "942f1817-8d40-48f8-8725-90c25a079e4c", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916312847, + "executionStopTime": 1730916313093, + "serverExecutionDuration": 3.410067060031, + "collapsed": false, + "requestMsgId": "942f1817-8d40-48f8-8725-90c25a079e4c" }, - "outputs": [], "source": [ - "model_bridge_with_GPEI.model.surrogate.botorch_model_class" + "model_bridge_with_GPEI.model.surrogate.model.__class__" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.models.gp_regression.SingleTaskGP" + }, + "metadata": {}, + "execution_count": 11 + } ] }, { "cell_type": "markdown", - "id": "connected-sheet", "metadata": { - "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + "metadata": { + "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + }, + "id": "connected-sheet", + "originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" @@ -483,12 +840,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "documentary-jurisdiction", "metadata": { - "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + "metadata": { + "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + }, + "id": "documentary-jurisdiction", + "originalKey": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916314009, + "executionStopTime": 1730916314736, + "serverExecutionDuration": 518.53136904538, + "collapsed": false, + "requestMsgId": "9c64c497-f663-42a6-aa48-1f1f2ae2b80b" }, - "outputs": [], "source": [ "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", " experiment=get_branin_experiment_with_multi_objective(\n", @@ -497,37 +863,116 @@ " data=get_branin_data_multi_objective(),\n", ")\n", "model_bridge_with_EHVI.gen(1)" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:05:14] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. 
If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:05:14] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "GeneratorRun(1 arms, total weight 1.0)" + }, + "metadata": {}, + "execution_count": 12 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "changed-maintenance", "metadata": { - "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + "metadata": { + "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + }, + "id": "changed-maintenance", + "originalKey": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916314586, + "executionStopTime": 1730916314842, + "serverExecutionDuration": 3.3097150735557, + "collapsed": false, + "requestMsgId": "ab6e84ac-2a55-4f48-9ab7-06b8d9b58d1f" }, - "outputs": [], "source": [ "model_bridge_with_EHVI.model.botorch_acqf_class" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement" + }, + "metadata": {}, + "execution_count": 13 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "operating-shelf", "metadata": { - "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + "metadata": { + "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + }, + "id": "operating-shelf", + "originalKey": "1e980e3c-09f6-44c1-a79f-f59867de0c3e", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916315097, + "executionStopTime": 1730916315308, + "serverExecutionDuration": 3.4662369871512, + "collapsed": false, + "requestMsgId": "1e980e3c-09f6-44c1-a79f-f59867de0c3e" }, - "outputs": [], "source": [ - "model_bridge_with_EHVI.model.surrogate.botorch_model_class" + "model_bridge_with_EHVI.model.surrogate.model.__class__" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.models.gp_regression.SingleTaskGP" + }, + "metadata": {}, + "execution_count": 14 + } ] }, { "cell_type": "markdown", - "id": "fatal-butterfly", "metadata": { - "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + "metadata": { + "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + }, + "id": "fatal-butterfly", + "originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " @@ -535,9 +980,16 @@ }, { "cell_type": "markdown", - "id": "hearing-interface", "metadata": { - "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + "metadata": { + "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + }, + "id": "hearing-interface", + "originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5. 
Utilizing `BoTorchModel` in generation strategies\n", @@ -549,12 +1001,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "received-registration", "metadata": { - "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + "metadata": { + "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + }, + "id": "received-registration", + "originalKey": "4ee172c8-0648-418b-9968-647e8e916507", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916316730, + "executionStopTime": 1730916316968, + "serverExecutionDuration": 2.2927720565349, + "collapsed": false, + "requestMsgId": "4ee172c8-0648-418b-9968-647e8e916507" }, - "outputs": [], "source": [ "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", @@ -576,19 +1037,30 @@ " # No limit on how many generator runs will be produced\n", " num_trials=-1,\n", " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", - " \"surrogate\": Surrogate(SingleTaskGP),\n", + " \"surrogate_spec\": SurrogateSpec(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " ),\n", " \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n", " },\n", " ),\n", " ]\n", ")" - ] + ], + "execution_count": 15, + "outputs": [] }, { "cell_type": "markdown", - "id": "logical-windsor", "metadata": { - "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + "metadata": { + "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + }, + "id": "logical-windsor", + "originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" @@ -596,24 +1068,58 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "viral-cheese", "metadata": { - "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + "metadata": { + "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + }, + "id": "viral-cheese", + "originalKey": "1b7d0cfc-f7cf-477d-b109-d34db9604938", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916317751, + "executionStopTime": 1730916318153, + "serverExecutionDuration": 3.9581339806318, + "collapsed": false, + "requestMsgId": "1b7d0cfc-f7cf-477d-b109-d34db9604938" }, - "outputs": [], "source": [ "experiment = get_branin_experiment(minimize=True)\n", "\n", "assert len(experiment.trials) == 0\n", "experiment.search_space" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 10:05:18] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. 
If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])" + }, + "metadata": {}, + "execution_count": 16 + } ] }, { "cell_type": "markdown", - "id": "incident-newspaper", "metadata": { - "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + "metadata": { + "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + }, + "id": "incident-newspaper", + "originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "## 5a. Specifying `pending_observations`\n", @@ -624,12 +1130,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "casual-spread", "metadata": { - "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + "metadata": { + "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + }, + "id": "casual-spread", + "originalKey": "fe7437c5-8834-46cc-94b2-91782d91ee96", + "outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916318830, + "executionStopTime": 1730916321328, + "serverExecutionDuration": 2274.8276960338, + "collapsed": false, + "requestMsgId": "fe7437c5-8834-46cc-94b2-91782d91ee96" }, - "outputs": [], "source": [ "for _ in range(10):\n", " # Produce a new generator run and attach it to experiment as a trial\n", @@ -648,13 +1163,65 @@ " trial.mark_completed()\n", "\n", " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #0, suggested by Sobol.\nCompleted trial #1, suggested by Sobol.\nCompleted trial #2, suggested by Sobol.\nCompleted trial #3, suggested by Sobol.\nCompleted trial #4, suggested by Sobol.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #5, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #6, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #7, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #8, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #9, suggested by BoTorch.\n" + ] + } ] }, { "cell_type": "markdown", - "id": "circular-vermont", "metadata": { - "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + "metadata": { + "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + }, + "id": "circular-vermont", + "originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25", + "showInput": false, + "outputsInitialized": false, + "isAgentGenerated": false, + "language": "markdown" }, "source": [ "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" @@ -662,21 +1229,200 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "significant-particular", "metadata": { - "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + "metadata": { + "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + }, + "id": "significant-particular", + "originalKey": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac", + 
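For reference, `get_pending_observation_features` returns a mapping from metric name to the `ObservationFeatures` of trials that have been suggested but have not yet reported data, which is what lets the acquisition function avoid re-suggesting them. A rough sketch of inspecting it, assuming the `experiment` and generation strategy from the cells above:

from ax.modelbridge.modelbridge_utils import get_pending_observation_features

pending = get_pending_observation_features(experiment=experiment)
if pending is not None:
    for metric_name, features in pending.items():
        # Each entry corresponds to a trial that is still missing data
        # for this metric.
        print(metric_name, [f.parameters for f in features])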
"outputsInitialized": true, + "isAgentGenerated": false, + "language": "python", + "executionStartTime": 1730916319576, + "executionStopTime": 1730916321368, + "serverExecutionDuration": 35.789265064523, + "collapsed": false, + "requestMsgId": "b3160bc0-d5d1-45fa-bf62-4b9dd5778cac" }, - "outputs": [], "source": [ "exp_to_df(experiment)" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[WARNING 11-06 10:05:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 26.922506 -2.244023 5.435609\n1 1 1_0 COMPLETED ... 74.072517 3.535081 10.528676\n2 2 2_0 COMPLETED ... 5.610080 8.741262 3.706691\n3 3 3_0 COMPLETED ... 56.657623 -0.069164 12.199905\n4 4 4_0 COMPLETED ... 27.932704 0.862014 1.306074\n5 5 5_0 COMPLETED ... 5.423062 10.000000 4.868411\n6 6 6_0 COMPLETED ... 9.250452 10.000000 0.299753\n7 7 7_0 COMPLETED ... 308.129096 -5.000000 0.000000\n8 8 8_0 COMPLETED ... 17.607633 0.778687 5.717932\n9 9 9_0 COMPLETED ... 132.986209 1.451895 15.000000\n\n[10 rows x 7 columns]", + "text/html": "
|   | trial_index | arm_name | trial_status | generation_method | branin | x1 | x2 |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0_0 | COMPLETED | Sobol | 26.922506 | -2.244023 | 5.435609 |
| 1 | 1 | 1_0 | COMPLETED | Sobol | 74.072517 | 3.535081 | 10.528676 |
| 2 | 2 | 2_0 | COMPLETED | Sobol | 5.610080 | 8.741262 | 3.706691 |
| 3 | 3 | 3_0 | COMPLETED | Sobol | 56.657623 | -0.069164 | 12.199905 |
| 4 | 4 | 4_0 | COMPLETED | Sobol | 27.932704 | 0.862014 | 1.306074 |
| 5 | 5 | 5_0 | COMPLETED | BoTorch | 5.423062 | 10.000000 | 4.868411 |
| 6 | 6 | 6_0 | COMPLETED | BoTorch | 9.250452 | 10.000000 | 0.299753 |
| 7 | 7 | 7_0 | COMPLETED | BoTorch | 308.129096 | -5.000000 | 0.000000 |
| 8 | 8 | 8_0 | COMPLETED | BoTorch | 17.607633 | 0.778687 | 5.717932 |
| 9 | 9 | 9_0 | COMPLETED | BoTorch | 132.986209 | 1.451895 | 15.000000 |
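Since `exp_to_df` returns a pandas DataFrame, the best arm found so far can be read off directly. A short follow-up sketch, assuming the column names shown in the table above and that Branin is being minimized in this experiment:

from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)
# Branin is minimized, so the smallest objective value is the best arm so far.
best_row = df.sort_values("branin").iloc[0]
print(best_row[["arm_name", "branin", "x1", "x2"]])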