From f13ce470c6eb5e6d48450fbe86239734e91e5258 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 6 Nov 2024 14:54:13 -0800 Subject: [PATCH] Support per-metric model specification in MBM (#3009) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3009 Enables using different models for different metrics. * Adds ModelConfig dataclass to specify a single botorch model configuration * Adds a list of ModelConfigs to SurrogateSpec * Adds a dictionary mapping metric names to list of ModelConfigs to enable per-metric model specification * Lists of model configs are used to enable per-metric model selection across multiple ModelConfigs in a subsequent diff. Reviewed By: saitcakmak Differential Revision: D64793595 fbshipit-source-id: 57d6c5d7c7b4e525d2699b2482bbb3370431f85a --- ax/modelbridge/tests/test_registry.py | 15 +- ax/models/torch/botorch_modular/model.py | 11 +- ax/models/torch/botorch_modular/surrogate.py | 222 +++- ax/models/torch/botorch_modular/utils.py | 109 +- ax/models/torch/tests/test_model.py | 21 +- ax/models/torch/tests/test_surrogate.py | 171 ++- ax/models/torch/tests/test_utils.py | 2 +- ax/storage/json_store/decoder.py | 3 +- ax/storage/json_store/registry.py | 2 + tutorials/modular_botax.ipynb | 1175 +++++++++++++++--- 10 files changed, 1415 insertions(+), 316 deletions(-) diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 418c4bcf889..7f77bbde548 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -10,6 +10,7 @@ import torch from ax.core.observation import ObservationFeatures +from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.modelbridge.discrete import DiscreteModelBridge from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import ( @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None: SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP + saasbo.model.surrogate.model_configs[0].botorch_model_class, + SaasFullyBayesianSingleTaskGP, ) @mock_botorch_optimize @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: self.assertIsInstance(mtgp, TorchModelBridge) self.assertIsInstance(mtgp.model, BoTorchModel) self.assertEqual(mtgp.model.acquisition_class, Acquisition) - self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + is_moo = isinstance( + exp.optimization_config, MultiObjectiveOptimizationConfig + ) + if is_moo: + self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + models = mtgp.model.surrogate.model.models + else: + models = [mtgp.model.surrogate.model] - for model in mtgp.model.surrogate.model.models: + for model in models: self.assertIsInstance( model, SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index a5e8f84aedc..03804f6f7de 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -28,6 +28,7 @@ check_outcome_dataset_match, choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, + ModelConfig, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -79,6 +80,8 @@ class SurrogateSpec: allow_batched_models: bool = True + model_configs: list[ModelConfig] = field(default_factory=list) + metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) outcomes: list[str] = field(default_factory=list) @@ -241,13 +244,19 @@ def fit( input_transform_options=spec.input_transform_options, outcome_transform_classes=spec.outcome_transform_classes, outcome_transform_options=spec.outcome_transform_options, + model_configs=spec.model_configs, + metric_to_model_configs=spec.metric_to_model_configs, allow_batched_models=spec.allow_batched_models, ) else: self._surrogate = Surrogate() # Fit the surrogate. - self.surrogate.model_options.update(additional_model_inputs) + for config in self.surrogate.model_configs: + config.model_options.update(additional_model_inputs) + for config_list in self.surrogate.metric_to_model_configs.values(): + for config in config_list: + config.model_options.update(additional_model_inputs) self.surrogate.fit( datasets=datasets, search_space_digest=search_space_digest, diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 0fab257ffca..80ecbd2a484 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -9,6 +9,7 @@ from __future__ import annotations import inspect +import warnings from collections import OrderedDict from collections.abc import Sequence from copy import deepcopy @@ -33,6 +34,7 @@ choose_model_class, convert_to_block_design, fit_botorch_model, + ModelConfig, subset_state_dict, use_model_list, ) @@ -50,11 +52,11 @@ _argparse_type_encoder, checked_cast, checked_cast_optional, + not_none, ) from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP -from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, @@ -68,6 +70,7 @@ from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from pyre_extensions import none_throws from torch import Tensor NOT_YET_FIT_MSG = ( @@ -268,6 +271,19 @@ def _set_formatted_inputs( formatted_model_inputs[input_name] = input_class(**input_options) +def _raise_deprecation_warning(*args: Any, **kwargs: Any) -> None: + for k, v in kwargs.items(): + if (v is not None and k != "mll_class") or ( + k == "mll_class" and v is not ExactMarginalLogLikelihood + ): + warnings.warn( + f"{k} is deprecated and will be removed in a future version. " + f"Please specify {k} via `model_configs`.", + DeprecationWarning, + stacklevel=3, + ) + + class Surrogate(Base): """ **All classes in 'botorch_modular' directory are under @@ -282,15 +298,20 @@ class Surrogate(Base): BoTorch model. If None is provided a model class will be selected (either one for all outcomes or a ModelList with separate models for each outcome) will be selected automatically based off the datasets at `construct` time. + This argument is deprecated in favor of model_configs. model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. Note that the corresponding attribute will later be updated to include any additional kwargs passed into ``BoTorchModel.fit``. + This argument is deprecated in favor of model_configs. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. - mll_options: Dictionary of options / kwargs for the MLL. + This argument is deprecated in favor of model_configs. + mll_options: Dictionary of options / kwargs for the MLL. This argument is + deprecated in favor of model_configs. outcome_transform_classes: List of BoTorch outcome transforms classes. Passed down to the BoTorch ``Model``. Multiple outcome transforms can be chained - together using ``ChainedOutcomeTransform``. + together using ``ChainedOutcomeTransform``. This argument is deprecated in + favor of model_configs. outcome_transform_options: Outcome transform classes kwargs. The keys are class string names and the values are dictionaries of outcome transform kwargs. For example, @@ -299,10 +320,12 @@ class string names and the values are dictionaries of outcome transform outcome_transform_options = { "Standardize": {"m": 1}, ` - For more options see `botorch/models/transforms/outcome.py`. + For more options see `botorch/models/transforms/outcome.py`. This argument + is deprecated in favor of model_configs. input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + This argument is deprecated in favor of model_configs. input_transform_options: Input transform classes kwargs. The keys are class string names and the values are dictionaries of input transform kwargs. For example, @@ -314,13 +337,22 @@ class string names and the values are dictionaries of input transform } ` For more input options see `botorch/models/transforms/input.py`. + This argument is deprecated in favor of model_configs. covar_module_class: Covariance module class. This gets initialized after parsing the ``covar_module_options`` in ``covar_module_argparse``, and gets passed to the model constructor as ``covar_module``. - covar_module_options: Covariance module kwargs. + This argument is deprecated in favor of model_configs. + covar_module_options: Covariance module kwargs. This argument is deprecated + in favor of model_configs. likelihood: ``Likelihood`` class. This gets initialized with ``likelihood_options`` and gets passed to the model constructor. - likelihood_options: Likelihood options. + This argument is deprecated in favor of model_configs. + likelihood_options: Likelihood options. This argument is deprecated in favor + of model_configs. + model_configs: List of model configs. Each model config is a specification of + a model. These should be used in favor of the above deprecated arguments. + metric_to_model_configs: Dictionary mapping metric names to a list of model + configs for that metric. allow_batched_models: Set to true to fit the models in a batch if supported. Set to false to fit individual models to each metric in a loop. """ @@ -339,22 +371,60 @@ def __init__( covar_module_options: dict[str, Any] | None = None, likelihood_class: type[Likelihood] | None = None, likelihood_options: dict[str, Any] | None = None, + model_configs: list[ModelConfig] | None = None, + metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> None: - self.botorch_model_class = botorch_model_class - # Copying model options to avoid mutating the original dict. - # We later update it with any additional kwargs passed into `BoTorchModel.fit`. - self.model_options: dict[str, Any] = (model_options or {}).copy() - self.mll_class = mll_class - self.mll_options: dict[str, Any] = mll_options or {} - self.outcome_transform_classes = outcome_transform_classes - self.outcome_transform_options: dict[str, Any] = outcome_transform_options or {} - self.input_transform_classes = input_transform_classes - self.input_transform_options: dict[str, Any] = input_transform_options or {} - self.covar_module_class = covar_module_class - self.covar_module_options: dict[str, Any] = covar_module_options or {} - self.likelihood_class = likelihood_class - self.likelihood_options: dict[str, Any] = likelihood_options or {} + _raise_deprecation_warning( + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + model_configs = model_configs or [] + if len(model_configs) == 0: + model_config_kwargs = { + "botorch_model_class": botorch_model_class, + "model_options": (model_options or {}).copy(), + "mll_class": mll_class, + "mll_options": mll_options, + "outcome_transform_classes": outcome_transform_classes, + "outcome_transform_options": outcome_transform_options, + "input_transform_classes": input_transform_classes, + "input_transform_options": input_transform_options, + "covar_module_class": covar_module_class, + "covar_module_options": covar_module_options, + "likelihood_class": likelihood_class, + "likelihood_options": likelihood_options, + } + model_config_kwargs = { + k: v for k, v in model_config_kwargs.items() if v is not None + } + # pyre-fixme [6]: Incompatible parameter type [6]: In call + # `ModelConfig.__init__`, for 1st positional argument, expected + # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], + # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], + # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, + # Model]], Type[Likelihood], Type[Kernel]]`. + model_configs = [ModelConfig(**model_config_kwargs)] + self.model_configs: list[ModelConfig] = model_configs + self.metric_to_model_configs: dict[str, list[ModelConfig]] = ( + metric_to_model_configs or {} + ) + if len(self.model_configs) > 1 or any( + len(model_config) > 1 + for model_config in self.metric_to_model_configs.values() + ): + raise NotImplementedError("Only one model config per metric is supported.") + self.allow_batched_models = allow_batched_models # Store the last dataset used to fit the model for a given metric(s). # If the new dataset is identical, we will skip model fitting for that metric. @@ -377,10 +447,8 @@ def __init__( def __repr__(self) -> str: return ( f"<{self.__class__.__name__}" - f" botorch_model_class={self.botorch_model_class} " - f"mll_class={self.mll_class} " - f"outcome_transform_classes={self.outcome_transform_classes} " - f"input_transform_classes={self.input_transform_classes} " + f" model_configs={self.model_configs}," + f" metric_to_model_configs={self.metric_to_model_configs}>" ) @property @@ -404,9 +472,7 @@ def Xs(self) -> list[Tensor]: training_data = self.training_data Xs = [] for dataset in training_data: - if self.botorch_model_class == PairwiseGP and isinstance( - dataset, RankingDataset - ): + if isinstance(dataset, RankingDataset): # directly accessing the d-dim X tensor values # instead of the augmented 2*d-dim dataset.X from RankingDataset Xi = checked_cast(SliceContainer, dataset._X).values @@ -431,7 +497,8 @@ def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - botorch_model_class: type[Model], + model_config: ModelConfig, + default_botorch_model_class: type[Model], state_dict: OrderedDict[str, Tensor] | None, refit: bool, ) -> Model: @@ -446,19 +513,24 @@ def _construct_model( multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. - botorch_model_class: ``Model`` class to be used as the underlying - BoTorch model. + model_config: The model_config. + default_botorch_model_class: The default ``Model`` class to be used as the + underlying BoTorch model, if the model_config does not specify one. state_dict: Optional state dict to load. This should be subsetted for the current submodel being constructed. refit: Whether to re-optimize model parameters. """ outcome_names = tuple(dataset.outcome_names) + botorch_model_class = ( + model_config.botorch_model_class or default_botorch_model_class + ) if self._should_reuse_last_model( dataset=dataset, botorch_model_class=botorch_model_class ): return self._submodels[outcome_names] formatted_model_inputs = submodel_input_constructor( botorch_model_class, # Do not pass as kwarg since this is used to dispatch. + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=self, @@ -469,7 +541,9 @@ def _construct_model( model.load_state_dict(state_dict) if state_dict is None or refit: fit_botorch_model( - model=model, mll_class=self.mll_class, mll_options=self.mll_options + model=model, + mll_class=model_config.mll_class, + mll_options=model_config.mll_options, ) self._submodels[outcome_names] = model self._last_datasets[outcome_names] = dataset @@ -539,14 +613,21 @@ def fit( # To determine whether to use ModelList under the hood, we need to check for # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. - botorch_model_class = self.botorch_model_class or choose_model_class( - datasets=datasets, search_space_digest=search_space_digest - ) - + if ( + len(self.model_configs) == 1 + and self.model_configs[0].botorch_model_class is None + ): + default_botorch_model_class = choose_model_class( + datasets=datasets, search_space_digest=search_space_digest + ) + else: + default_botorch_model_class = self.model_configs[0].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, - botorch_model_class=botorch_model_class, + botorch_model_class=not_none(default_botorch_model_class), + model_configs=self.model_configs, allow_batched_models=self.allow_batched_models, + metric_to_model_configs=self.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: @@ -564,10 +645,30 @@ def fit( ) else: submodel_state_dict = state_dict + model_config = None + if len(self.metric_to_model_configs) > 0: + # if metric_to_model_configs is not empty, then + # we are using a model list and each dataset + # should have only one outcome. + if len(dataset.outcome_names) > 1: + raise ValueError( + "Each dataset should have only one outcome when " + "metric_to_model_configs is specified." + ) + model_config_list = self.metric_to_model_configs.get( + dataset.outcome_names[0] + ) + + # TODO: add support for automated model selection + if model_config_list is not None: + model_config = model_config_list[0] + if model_config is None: + model_config = self.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, - botorch_model_class=botorch_model_class, + model_config=model_config, + default_botorch_model_class=not_none(default_botorch_model_class), state_dict=submodel_state_dict, refit=refit, ) @@ -728,23 +829,12 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: as kwargs on reinstantiation. """ return { - "botorch_model_class": self.botorch_model_class, - "model_options": self.model_options, - "mll_class": self.mll_class, - "mll_options": self.mll_options, - "outcome_transform_classes": self.outcome_transform_classes, - "outcome_transform_options": self.outcome_transform_options, - "input_transform_classes": self.input_transform_classes, - "input_transform_options": self.input_transform_options, - "covar_module_class": self.covar_module_class, - "covar_module_options": self.covar_module_options, - "likelihood_class": self.likelihood_class, - "likelihood_options": self.likelihood_options, - "allow_batched_models": self.allow_batched_models, + "model_configs": self.model_configs, + "metric_to_model_configs": self.metric_to_model_configs, } def _extract_construct_input_transform_args( - self, search_space_digest: SearchSpaceDigest + self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest ) -> tuple[Sequence[type[InputTransform]] | None, dict[str, dict[str, Any]]]: """ Extracts input transform classes and input transform options that will @@ -777,19 +867,19 @@ def _extract_construct_input_transform_args( InputPerturbation ] - if self.input_transform_classes is not None: + if model_config.input_transform_classes is not None: # TODO: Support mixing with user supplied transforms. raise NotImplementedError( "User supplied input transforms are not supported " "in robust optimization." ) else: - submodel_input_transform_classes = self.input_transform_classes - submodel_input_transform_options = self.input_transform_options + submodel_input_transform_classes = model_config.input_transform_classes + submodel_input_transform_options = model_config.input_transform_options return ( submodel_input_transform_classes, - submodel_input_transform_options, + none_throws(submodel_input_transform_options), ) @property @@ -811,6 +901,7 @@ def outcomes(self, value: list[str]) -> None: @submodel_input_constructor.register(Model) def _submodel_input_constructor_base( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -819,6 +910,7 @@ def _submodel_input_constructor_base( Args: botorch_model_class: The BoTorch model class to instantiate. + model_config: The model config. dataset: The training data for the model. search_space_digest: Search space digest used to set up model arguments. surrogate: A reference to the surrogate that created the model. @@ -836,12 +928,12 @@ def _submodel_input_constructor_base( input_transform_classes, input_transform_options, ) = surrogate._extract_construct_input_transform_args( - search_space_digest=search_space_digest + model_config=model_config, search_space_digest=search_space_digest ) formatted_model_inputs = botorch_model_class.construct_inputs( training_data=dataset, - **surrogate.model_options, + **model_config.model_options, **model_kwargs_from_ss, ) @@ -851,14 +943,18 @@ def _submodel_input_constructor_base( inputs=[ ( "covar_module", - surrogate.covar_module_class, - surrogate.covar_module_options, + model_config.covar_module_class, + model_config.covar_module_options, + ), + ( + "likelihood", + model_config.likelihood_class, + model_config.likelihood_options, ), - ("likelihood", surrogate.likelihood_class, surrogate.likelihood_options), ( "outcome_transform", - surrogate.outcome_transform_classes, - surrogate.outcome_transform_options, + model_config.outcome_transform_classes, + model_config.outcome_transform_options, ), ( "input_transform", @@ -880,6 +976,7 @@ def _submodel_input_constructor_base( @submodel_input_constructor.register(MultiTaskGP) def _submodel_input_constructor_mtgp( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -888,6 +985,7 @@ def _submodel_input_constructor_mtgp( raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=surrogate, diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 26a273ddd4e..bd207a08973 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -9,6 +9,7 @@ import warnings from collections import OrderedDict from collections.abc import Sequence +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -34,8 +35,13 @@ from botorch.models.model import Model, ModelList from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset from botorch.utils.transforms import is_fully_bayesian +from gpytorch.kernels.kernel import Kernel +from gpytorch.likelihoods import Likelihood +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from pyre_extensions import none_throws from torch import Tensor @@ -44,20 +50,115 @@ logger: Logger = get_logger(__name__) +@dataclass +class ModelConfig: + """Configuration for the BoTorch Model used in Surrogate. + + Args: + botorch_model_class: ``Model`` class to be used as the underlying + BoTorch model. If None is provided a model class will be selected (either + one for all outcomes or a ModelList with separate models for each outcome) + will be selected automatically based off the datasets at `construct` time. + model_options: Dictionary of options / kwargs for the BoTorch + ``Model`` constructed during ``Surrogate.fit``. + Note that the corresponding attribute will later be updated to include any + additional kwargs passed into ``BoTorchModel.fit``. + mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. + This argument is deprecated in favor of model_configs. + mll_options: Dictionary of options / kwargs for the MLL. + outcome_transform_classes: List of BoTorch outcome transforms classes. Passed + down to the BoTorch ``Model``. Multiple outcome transforms can be chained + together using ``ChainedOutcomeTransform``. + outcome_transform_options: Outcome transform classes kwargs. The keys are + class string names and the values are dictionaries of outcome transform + kwargs. For example, + ` + outcome_transform_classes = [Standardize] + outcome_transform_options = { + "Standardize": {"m": 1}, + ` + For more options see `botorch/models/transforms/outcome.py`. + input_transform_classes: List of BoTorch input transforms classes. + Passed down to the BoTorch ``Model``. Multiple input transforms + will be chained together using ``ChainedInputTransform``. + input_transform_options: Input transform classes kwargs. The keys are + class string names and the values are dictionaries of input transform + kwargs. For example, + ` + input_transform_classes = [Normalize, Round] + input_transform_options = { + "Normalize": {"d": 3}, + "Round": {"integer_indices": [0], "categorical_features": {1: 2}}, + } + ` + For more input options see `botorch/models/transforms/input.py`. + covar_module_class: Covariance module class. This gets initialized after + parsing the ``covar_module_options`` in ``covar_module_argparse``, + and gets passed to the model constructor as ``covar_module``. + covar_module_options: Covariance module kwargs. + in favor of model_configs. + likelihood: ``Likelihood`` class. This gets initialized with + ``likelihood_options`` and gets passed to the model constructor. + This argument is deprecated in favor of model_configs. + likelihood_options: Likelihood options. + """ + + botorch_model_class: type[Model] | None = None + model_options: dict[str, Any] = field(default_factory=dict) + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_options: dict[str, Any] = field(default_factory=dict) + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = field( + default_factory=dict + ) + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict) + covar_module_class: type[Kernel] | None = None + covar_module_options: dict[str, Any] = field(default_factory=dict) + likelihood_class: type[Likelihood] | None = None + likelihood_options: dict[str, Any] = field(default_factory=dict) + + def use_model_list( datasets: Sequence[SupervisedDataset], botorch_model_class: type[Model], + model_configs: list[ModelConfig] | None = None, + metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> bool: - if issubclass(botorch_model_class, MultiTaskGP): - # We currently always wrap multi-task models into `ModelListGP`. + model_configs = model_configs or [] + metric_to_model_configs = metric_to_model_configs or {} + if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1: + # There is only one outcome, so we can use a single model. + return False + elif ( + len(model_configs) > 1 + or len(metric_to_model_configs) > 0 + or any(len(model_config) for model_config in metric_to_model_configs.values()) + ): + # There are multiple outcomes and outcomes might be modeled with different + # models return True - elif issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): + # Otherwise, the same model class is used for all outcomes. + # Determine what the model class is. + if len(model_configs) > 0: + botorch_model_class = ( + model_configs[0].botorch_model_class or botorch_model_class + ) + if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): # SAAS models do not support multiple outcomes. # Use model list if there are multiple outcomes. return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 + elif issubclass(botorch_model_class, MultiTaskGP): + # We wrap multi-task models into `ModelListGP` when there are + # multiple outcomes. + return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 elif len(datasets) == 1: - # Just one outcome, can use single model. + # This method is called before multiple datasets are merged into + # one if using a batched model. If there is one dataset here, + # there should be a reason that a single model should be used: + # e.g. a contextual model, where we want to jointly model the metric + # each context (and context-level metrics are different outcomes). return False elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all( torch.equal(datasets[0].X, ds.X) for ds in datasets[1:] diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index c8cf1ddba7a..389b3072175 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -27,6 +27,7 @@ from ax.models.torch.botorch_modular.utils import ( choose_model_class, construct_acquisition_and_optimizer_options, + ModelConfig, ) from ax.models.torch.utils import _filter_X_observed from ax.models.torch_base import TorchOptConfig @@ -327,7 +328,21 @@ def test_fit(self, mock_fit: Mock) -> None: mock_fit.assert_called_with( dataset=self.block_design_training_data[0], search_space_digest=self.mf_search_space_digest, - botorch_model_class=SingleTaskMultiFidelityGP, + model_config=ModelConfig( + botorch_model_class=None, + model_options={}, + mll_class=ExactMarginalLogLikelihood, + mll_options={}, + input_transform_classes=None, + input_transform_options={}, + outcome_transform_classes=None, + outcome_transform_options={}, + covar_module_class=None, + covar_module_options={}, + likelihood_class=None, + likelihood_options={}, + ), + default_botorch_model_class=SingleTaskMultiFidelityGP, state_dict=None, refit=True, ) @@ -727,6 +742,8 @@ def test_surrogate_model_options_propagation( input_transform_options=None, outcome_transform_classes=None, outcome_transform_options=None, + model_configs=[], + metric_to_model_configs={}, allow_batched_models=True, ) @@ -755,6 +772,8 @@ def test_surrogate_options_propagation( input_transform_options=None, outcome_transform_classes=None, outcome_transform_options=None, + model_configs=[], + metric_to_model_configs={}, allow_batched_models=False, ) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 2d6b12b7ef0..6bbd600df90 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -18,8 +18,13 @@ from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition +from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate -from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model +from ax.models.torch.botorch_modular.utils import ( + choose_model_class, + fit_botorch_model, + ModelConfig, +) from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast @@ -37,7 +42,7 @@ from botorch.models.transforms.outcome import Standardize from botorch.utils.datasets import SupervisedDataset from gpytorch.constraints import GreaterThan, Interval -from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood from pyre_extensions import assert_is_instance, none_throws @@ -216,8 +221,10 @@ def _get_surrogate( def test_init(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - self.assertEqual(surrogate.botorch_model_class, botorch_model_class) - self.assertEqual(surrogate.mll_class, self.mll_class) + self.assertEqual( + surrogate.model_configs[0].botorch_model_class, botorch_model_class + ) + self.assertEqual(surrogate.model_configs[0].mll_class, self.mll_class) self.assertTrue(surrogate.allow_batched_models) # True by default def test_clone_reset(self) -> None: @@ -432,7 +439,8 @@ def test_construct_model(self) -> None: Surrogate()._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=Model, + model_config=ModelConfig(), + default_botorch_model_class=Model, state_dict=None, refit=True, ) @@ -446,7 +454,8 @@ def test_construct_model(self) -> None: model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -471,7 +480,8 @@ def test_construct_model(self) -> None: new_model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -485,7 +495,8 @@ def test_construct_model(self) -> None: surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=SingleTaskGPWithDifferentConstructor, + model_config=ModelConfig(), + default_botorch_model_class=SingleTaskGPWithDifferentConstructor, state_dict=None, refit=True, ) @@ -497,14 +508,18 @@ def test_construct_model(self) -> None: ) @mock_botorch_optimize - def test_construct_custom_model(self) -> None: + def test_construct_custom_model(self, use_model_config: bool = False) -> None: # Test error for unsupported covar_module and likelihood. - surrogate = Surrogate( - botorch_model_class=SingleTaskGPWithDifferentConstructor, - mll_class=self.mll_class, - covar_module_class=RBFKernel, - likelihood_class=FixedNoiseGaussianLikelihood, - ) + model_config_kwargs: dict[str, Any] = { + "botorch_model_class": SingleTaskGPWithDifferentConstructor, + "mll_class": self.mll_class, + "covar_module_class": RBFKernel, + "likelihood_class": FixedNoiseGaussianLikelihood, + } + if use_model_config: + surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + else: + surrogate = Surrogate(**model_config_kwargs) with self.assertRaisesRegex(UserInputError, "does not support"): surrogate.fit( self.training_data, @@ -512,14 +527,18 @@ def test_construct_custom_model(self) -> None: ) # Pass custom options to a SingleTaskGP and make sure they are used noise_constraint = Interval(1e-6, 1e-1) - surrogate = Surrogate( - botorch_model_class=SingleTaskGP, - mll_class=LeaveOneOutPseudoLikelihood, - covar_module_class=RBFKernel, - covar_module_options={"ard_num_dims": 3}, - likelihood_class=GaussianLikelihood, - likelihood_options={"noise_constraint": noise_constraint}, - ) + model_config_kwargs = { + "botorch_model_class": SingleTaskGP, + "mll_class": LeaveOneOutPseudoLikelihood, + "covar_module_class": RBFKernel, + "covar_module_options": {"ard_num_dims": 3}, + "likelihood_class": GaussianLikelihood, + "likelihood_options": {"noise_constraint": noise_constraint}, + } + if use_model_config: + surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + else: + surrogate = Surrogate(**model_config_kwargs) surrogate.fit( self.training_data, search_space_digest=self.search_space_digest, @@ -532,10 +551,65 @@ def test_construct_custom_model(self) -> None: model.likelihood.noise_covar.raw_noise_constraint.__dict__, noise_constraint.__dict__, ) - self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) + self.assertEqual( + surrogate.model_configs[0].mll_class, LeaveOneOutPseudoLikelihood + ) self.assertEqual(type(model.covar_module), RBFKernel) self.assertEqual(model.covar_module.ard_num_dims, 3) + def test_construct_custom_model_with_config(self) -> None: + self.test_construct_custom_model(use_model_config=True) + + def test_construct_model_with_metric_to_model_configs(self) -> None: + surrogate = Surrogate( + metric_to_model_configs={ + "metric": [ModelConfig()], + "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], + }, + model_configs=[ModelConfig(covar_module_class=LinearKernel)], + ) + training_data = self.training_data + [ + SupervisedDataset( + X=self.Xs[0], + # Note: using 1d Y does not match the 2d TorchOptConfig + Y=self.Ys[0], + feature_names=self.feature_names, + outcome_names=[f"metric{i}"], + ) + for i in range(2, 5) + ] + surrogate.fit( + datasets=training_data, search_space_digest=self.search_space_digest + ) + # test model follows metric_to_model_configs for + # first two metrics + self.assertIsInstance(surrogate.model, ModelListGP) + submodels = surrogate.model.models + self.assertEqual(len(submodels), 4) + for m in submodels: + self.assertIsInstance(m, SingleTaskGP) + self.assertIsInstance(surrogate.model.models[1].covar_module, ScaleKernel) + self.assertIsInstance( + surrogate.model.models[1].covar_module.base_kernel, MaternKernel + ) + self.assertIsInstance(surrogate.model.models[0].covar_module, RBFKernel) + # test model use model_configs for the third metric + self.assertIsInstance(surrogate.model.models[2].covar_module, LinearKernel) + + def test_multiple_model_configs_error(self) -> None: + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + Surrogate( + model_configs=[ModelConfig(), ModelConfig()], + ) + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + Surrogate( + metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]}, + ) + @mock_botorch_optimize @patch(f"{SURROGATE_PATH}.predict_from_model") def test_predict(self, mock_predict: Mock) -> None: @@ -647,7 +721,8 @@ def test_serialize_attributes_as_kwargs(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) expected = { - k: v for k, v in surrogate.__dict__.items() if not k.startswith("_") + "model_configs": surrogate.model_configs, + "metric_to_model_configs": surrogate.metric_to_model_configs, } self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) @@ -677,7 +752,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.training_data, @@ -830,11 +905,12 @@ def setUp(self) -> None: ) def test_init(self) -> None: + model_config = self.surrogate.model_configs[0] self.assertEqual( - [self.surrogate.botorch_model_class] * 2, + [model_config.botorch_model_class] * 2, [*self.botorch_submodel_class_per_outcome.values()], ) - self.assertEqual(self.surrogate.mll_class, self.mll_class) + self.assertEqual(model_config.mll_class, self.mll_class) with self.assertRaisesRegex( ValueError, "BoTorch `Model` has not yet been constructed" ): @@ -849,7 +925,7 @@ def test_init(self) -> None: def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock ) -> None: - self.surrogate.model_options.update({"output_tasks": [2]}) + self.surrogate.model_configs[0].model_options.update({"output_tasks": [2]}) for fixed_noise in (False, True): mock_fit.reset_mock() mock_MTGP_construct_inputs.reset_mock() @@ -940,10 +1016,14 @@ def test_fit( self.assertIsNone(surrogate._model) # Should instantiate mll and `fit_gpytorch_mll` when `state_dict` # is `None`. - # pyre-ignore[6]: Incompatible parameter type: In call - # `issubclass`, for 1st positional argument, expected - # `Type[typing.Any]` but got `Optional[Type[Model]]`. - is_mtgp = issubclass(surrogate.botorch_model_class, MultiTaskGP) + + is_mtgp = issubclass( + # pyre-ignore[6]: Incompatible parameter type: In call + # `issubclass`, for 1st positional argument, expected + # `Type[typing.Any]` but got `Optional[Type[Model]]`. + surrogate.model_configs[0].botorch_model_class, + MultiTaskGP, + ) search_space_digest = ( self.multi_task_search_space_digest if is_mtgp @@ -1097,25 +1177,6 @@ def test_with_botorch_transforms(self) -> None: ) ) - def test_serialize_attributes_as_kwargs(self) -> None: - # TODO[mpolson64] Reimplement this when serialization has been sorted out - pass - # expected = self.surrogate.__dict__ - # # The two attributes below don't need to be saved as part of state, - # # so we remove them from the expected dict. - # for attr_name in ( - # "botorch_model_class", - # "model_options", - # "covar_module_class", - # "covar_module_options", - # "likelihood_class", - # "likelihood_options", - # "outcome_transform", - # "input_transform", - # ): - # expected.pop(attr_name) - # self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected) - @mock_botorch_optimize def test_construct_custom_model(self) -> None: noise_constraint = Interval(1e-4, 10.0) @@ -1144,7 +1205,9 @@ def test_construct_custom_model(self) -> None: ) models = checked_cast(ModelListGP, surrogate._model).models self.assertEqual(len(models), 2) - self.assertEqual(surrogate.mll_class, ExactMarginalLogLikelihood) + self.assertEqual( + surrogate.model_configs[0].mll_class, ExactMarginalLogLikelihood + ) # Make sure we properly copied the transforms. self.assertNotEqual( id(models[0].input_transform), id(models[1].input_transform) @@ -1197,7 +1260,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.supervised_training_data, diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 19d263f5fa1..19e05552b68 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -306,7 +306,7 @@ def test_use_model_list(self) -> None: botorch_model_class=SingleTaskGP, ) ) - self.assertTrue( + self.assertFalse( use_model_list( datasets=self.supervised_datasets, botorch_model_class=MultiTaskGP ) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index d39e46801dd..5b31efb70f8 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -46,6 +46,7 @@ ) from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( batch_trial_from_json, botorch_component_from_json, @@ -229,7 +230,7 @@ def object_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - elif _class in (SurrogateSpec, Surrogate): + elif _class in (SurrogateSpec, Surrogate, ModelConfig): if "input_transform" in object_json: ( input_transform_classes_json, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index a798c30b1fe..02498d03669 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -98,6 +98,7 @@ from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType @@ -330,6 +331,7 @@ "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, "Models": Models, "ModelRegistryBase": ModelRegistryBase, + "ModelConfig": ModelConfig, "ModelSpec": ModelSpec, "MultiObjective": MultiObjective, "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index 275530f2f3c..da21088c63c 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -1,13 +1,43 @@ { + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5, "cells": [ { "cell_type": "code", - "execution_count": null, - "id": "about-preview", "metadata": { - "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + "collapsed": false, + "executionStartTime": 1730904956474, + "executionStopTime": 1730904963470, + "id": "about-preview", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + }, + "originalKey": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", + "outputsInitialized": true, + "requestMsgId": "b3373e27-c3fa-41de-bf4b-3adb0f0571e7", + "serverExecutionDuration": 4351.2808320811 }, - "outputs": [], "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", "\n", @@ -39,13 +69,37 @@ "# BoTorch components\n", "from botorch.models.model import Model\n", "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1106 065557.352 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1106 065557.353 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" + ] + } ] }, { "cell_type": "markdown", - "id": "northern-affairs", "metadata": { - "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + "id": "northern-affairs", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + }, + "originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c", + "outputsInitialized": false, + "showInput": false }, "source": [ "# Setup and Usage of BoTorch Models in Ax\n", @@ -70,9 +124,16 @@ }, { "cell_type": "markdown", - "id": "pending-support", "metadata": { - "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + "id": "pending-support", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + }, + "originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 1. Quick-start example\n", @@ -82,25 +143,53 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "parental-sending", "metadata": { - "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + "collapsed": false, + "executionStartTime": 1730904958719, + "executionStopTime": 1730904963489, + "id": "parental-sending", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + }, + "originalKey": "146a9a1b-52e6-4d76-9fc5-79025b392673", + "outputsInitialized": true, + "requestMsgId": "146a9a1b-52e6-4d76-9fc5-79025b392673", + "serverExecutionDuration": 42.191333021037 }, - "outputs": [], "source": [ "experiment = get_branin_experiment(with_trial=True)\n", "data = get_branin_data(trials=[experiment.trials[0]])" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:56:01] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + ] + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "rough-somerset", "metadata": { - "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + "collapsed": false, + "executionStartTime": 1730904959343, + "executionStopTime": 1730904964892, + "id": "rough-somerset", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + }, + "originalKey": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", + "outputsInitialized": true, + "requestMsgId": "aa532754-01ad-4441-84c1-2ac7f54ecf1e", + "serverExecutionDuration": 870.78339292202 }, - "outputs": [], "source": [ "# `Models` automatically selects a model + model bridge combination.\n", "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", @@ -110,13 +199,30 @@ " surrogate=Surrogate(SingleTaskGP), # Optional, will use default if unspecified\n", " botorch_acqf_class=qLogNoisyExpectedImprovement, # Optional, will use default if unspecified\n", ")" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:56:01] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + } ] }, { "cell_type": "markdown", - "id": "hairy-wiring", "metadata": { - "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + "id": "hairy-wiring", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + }, + "originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f", + "outputsInitialized": false, + "showInput": false }, "source": [ "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." @@ -124,22 +230,49 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "consecutive-summary", "metadata": { - "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + "collapsed": false, + "executionStartTime": 1730904961333, + "executionStopTime": 1730904964907, + "id": "consecutive-summary", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + }, + "originalKey": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", + "outputsInitialized": true, + "requestMsgId": "c0051dd9-bf05-42bc-b4c3-ae5b99eba696", + "serverExecutionDuration": 284.31268292479 }, - "outputs": [], "source": [ "generator_run = model_bridge_with_GPEI.gen(n=1)\n", "generator_run.arms[0]" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "Arm(parameters={'x1': -5.0, 'x2': 0.0})" + }, + "metadata": {}, + "execution_count": 4 + } ] }, { "cell_type": "markdown", - "id": "diverse-richards", "metadata": { - "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + "id": "diverse-richards", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + }, + "originalKey": "804bac30-db07-4444-98a2-7a5f05007495", + "outputsInitialized": false, + "showInput": false }, "source": [ "-----\n", @@ -151,9 +284,16 @@ }, { "cell_type": "markdown", - "id": "grand-committee", "metadata": { - "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + "id": "grand-committee", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + }, + "originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 2. BoTorchModel = Surrogate + Acquisition\n", @@ -163,9 +303,16 @@ }, { "cell_type": "markdown", - "id": "thousand-blanket", "metadata": { - "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + "id": "thousand-blanket", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + }, + "originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 2A. Example that uses defaults and requires no options\n", @@ -175,30 +322,52 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "changing-xerox", "metadata": { - "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + "collapsed": false, + "executionStartTime": 1730905022012, + "executionStopTime": 1730905022269, + "id": "changing-xerox", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + }, + "originalKey": "509e30d7-dc32-4190-836f-f221cacbff31", + "outputsInitialized": true, + "requestMsgId": "509e30d7-dc32-4190-836f-f221cacbff31", + "serverExecutionDuration": 1.972567057237 }, - "outputs": [], "source": [ + "from ax.models.torch.botorch_modular.utils import ModelConfig\n", + "\n", "# The surrogate is not specified, so it will be auto-selected\n", "# during `model.fit`.\n", "GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)\n", "\n", "# The acquisition class is not specified, so it will be\n", "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", - "GPEI_model = BoTorchModel(surrogate=Surrogate(SingleTaskGP))\n", + "GPEI_model = BoTorchModel(\n", + " surrogate=Surrogate(model_configs=[ModelConfig(SingleTaskGP)])\n", + ")\n", "\n", "# Both the surrogate and acquisition class will be auto-selected.\n", "GPEI_model = BoTorchModel()" - ] + ], + "execution_count": 7, + "outputs": [] }, { "cell_type": "markdown", - "id": "lovely-mechanics", "metadata": { - "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + "id": "lovely-mechanics", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + }, + "originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 2B. Example with all the options\n", @@ -207,23 +376,36 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "twenty-greek", "metadata": { - "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + "collapsed": false, + "executionStartTime": 1730905023369, + "executionStopTime": 1730905023622, + "id": "twenty-greek", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + }, + "originalKey": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", + "outputsInitialized": true, + "requestMsgId": "a63a4a66-07c7-42b8-8c6b-fda19d9c7f03", + "serverExecutionDuration": 1.9794970285147 }, - "outputs": [], "source": [ "model = BoTorchModel(\n", " # Optional `Surrogate` specification to use instead of default\n", " surrogate=Surrogate(\n", - " # BoTorch `Model` type\n", - " botorch_model_class=SingleTaskGP,\n", - " # Optional, MLL class with which to optimize model parameters\n", - " mll_class=ExactMarginalLogLikelihood,\n", - " # Optional, dictionary of keyword arguments to underlying\n", - " # BoTorch `Model` constructor\n", - " model_options={},\n", + " model_configs=[\n", + " ModelConfig(\n", + " # BoTorch `Model` type\n", + " botorch_model_class=SingleTaskGP,\n", + " # Optional, MLL class with which to optimize model parameters\n", + " mll_class=ExactMarginalLogLikelihood,\n", + " # Optional, dictionary of keyword arguments to underlying\n", + " # BoTorch `Model` constructor\n", + " model_options={},\n", + " )\n", + " ]\n", " ),\n", " # Optional BoTorch `AcquisitionFunction` to use instead of default\n", " botorch_acqf_class=qLogExpectedImprovement,\n", @@ -238,13 +420,22 @@ " refit_on_cv=False,\n", " warm_start_refit=True,\n", ")" - ] + ], + "execution_count": 8, + "outputs": [] }, { "cell_type": "markdown", - "id": "fourth-material", "metadata": { - "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + "id": "fourth-material", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + }, + "originalKey": "7140bb19-09b4-4abe-951d-53902ae07833", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", @@ -258,9 +449,16 @@ }, { "cell_type": "markdown", - "id": "violent-course", "metadata": { - "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + "id": "violent-course", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + }, + "originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" @@ -268,11 +466,18 @@ }, { "cell_type": "markdown", - "id": "unlike-football", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "id": "unlike-football", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "showInput": false + }, + "originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc", + "outputsInitialized": false, "showInput": false }, "source": [ @@ -286,16 +491,24 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "dynamic-university", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + "collapsed": false, + "executionStartTime": 1730905050779, + "executionStopTime": 1730905051022, + "id": "dynamic-university", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + }, + "originalKey": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", + "outputsInitialized": true, + "requestMsgId": "1830b3f5-fd8e-4151-97d9-8a2aaa6885f7", + "serverExecutionDuration": 2.5736939860508 }, - "outputs": [], "source": [ - "from botorch.models.model import Model\n", "from botorch.utils.datasets import SupervisedDataset\n", "\n", "\n", @@ -318,17 +531,30 @@ "\n", "\n", "surrogate = Surrogate(\n", - " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", - " # Optional dict of additional keyword arguments to `MyModelClass`\n", - " model_options={},\n", + " model_configs=[\n", + " ModelConfig(\n", + " botorch_model_class=MyModelClass, # Must implement `construct_inputs`\n", + " # Optional dict of additional keyword arguments to `MyModelClass`\n", + " model_options={},\n", + " )\n", + " ],\n", ")" - ] + ], + "execution_count": 9, + "outputs": [] }, { "cell_type": "markdown", - "id": "otherwise-context", "metadata": { - "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + "id": "otherwise-context", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + }, + "originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666", + "outputsInitialized": false, + "showInput": false }, "source": [ "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." @@ -336,9 +562,16 @@ }, { "cell_type": "markdown", - "id": "northern-invite", "metadata": { - "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + "id": "northern-invite", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + }, + "originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" @@ -346,11 +579,18 @@ }, { "cell_type": "markdown", - "id": "surrounded-denial", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "id": "surrounded-denial", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "showInput": false + }, + "originalKey": "d4861847-b757-4fcd-9f35-ba258080812c", + "outputsInitialized": false, "showInput": false }, "source": [ @@ -363,14 +603,23 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "interested-search", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + "collapsed": false, + "executionStartTime": 1730905067704, + "executionStopTime": 1730905067939, + "id": "interested-search", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + }, + "originalKey": "075243f0-c2d6-46fd-9a18-5b77af258abf", + "outputsInitialized": true, + "requestMsgId": "075243f0-c2d6-46fd-9a18-5b77af258abf", + "serverExecutionDuration": 5.0081580411643 }, - "outputs": [], "source": [ "from botorch.acquisition.acquisition import AcquisitionFunction\n", "from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n", @@ -405,13 +654,31 @@ " \"optimizer_options\": {\"sequential\": False},\n", " },\n", ")" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "BoTorchModel" + }, + "metadata": {}, + "execution_count": 10 + } ] }, { "cell_type": "markdown", - "id": "metallic-imaging", "metadata": { - "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + "id": "metallic-imaging", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + }, + "originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863", + "outputsInitialized": false, + "showInput": false }, "source": [ "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." @@ -419,9 +686,16 @@ }, { "cell_type": "markdown", - "id": "descending-australian", "metadata": { - "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + "id": "descending-australian", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + }, + "originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 4. Using `Models.BOTORCH_MODULAR` \n", @@ -433,49 +707,123 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "attached-border", "metadata": { - "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + "collapsed": false, + "executionStartTime": 1730905070804, + "executionStopTime": 1730905071313, + "id": "attached-border", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + }, + "originalKey": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", + "outputsInitialized": true, + "requestMsgId": "980d8513-8607-4099-8cfe-7c8d7bf5afe9", + "serverExecutionDuration": 262.54303695168 }, - "outputs": [], "source": [ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", ")\n", "model_bridge_with_GPEI.gen(1)" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:57:51] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "GeneratorRun(1 arms, total weight 1.0)" + }, + "metadata": {}, + "execution_count": 11 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "powerful-gamma", "metadata": { - "originalKey": "89930a31-e058-434b-b587-181931e247b6" + "collapsed": false, + "executionStartTime": 1730905071744, + "executionStopTime": 1730905071981, + "id": "powerful-gamma", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "89930a31-e058-434b-b587-181931e247b6" + }, + "originalKey": "6ec047f0-a75e-4733-b4d7-20045627f0b2", + "outputsInitialized": true, + "requestMsgId": "6ec047f0-a75e-4733-b4d7-20045627f0b2", + "serverExecutionDuration": 2.8772989753634 }, - "outputs": [], "source": [ "model_bridge_with_GPEI.model.botorch_acqf_class" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.acquisition.logei.qLogNoisyExpectedImprovement" + }, + "metadata": {}, + "execution_count": 12 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "improved-replication", "metadata": { - "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + "collapsed": false, + "executionStartTime": 1730905076319, + "executionStopTime": 1730905076580, + "id": "improved-replication", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + }, + "originalKey": "68caa729-792d-4692-bc4d-4c0b8d03e022", + "outputsInitialized": true, + "requestMsgId": "68caa729-792d-4692-bc4d-4c0b8d03e022", + "serverExecutionDuration": 3.2039729412645 }, - "outputs": [], "source": [ - "model_bridge_with_GPEI.model.surrogate.botorch_model_class" + "type(model_bridge_with_GPEI.model.surrogate.model)" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.models.gp_regression.SingleTaskGP" + }, + "metadata": {}, + "execution_count": 13 + } ] }, { "cell_type": "markdown", - "id": "connected-sheet", "metadata": { - "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + "id": "connected-sheet", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + }, + "originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08", + "outputsInitialized": false, + "showInput": false }, "source": [ "We can use the same `Models.BOTORCH_MODULAR` to set up a model for multi-objective optimization:" @@ -483,12 +831,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "documentary-jurisdiction", "metadata": { - "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + "collapsed": false, + "executionStartTime": 1730905078786, + "executionStopTime": 1730905079551, + "id": "documentary-jurisdiction", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + }, + "originalKey": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", + "outputsInitialized": true, + "requestMsgId": "8ab2462c-c927-4a7c-95cb-c281b9b7f1be", + "serverExecutionDuration": 512.32317101676 }, - "outputs": [], "source": [ "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", " experiment=get_branin_experiment_with_multi_objective(\n", @@ -497,37 +854,116 @@ " data=get_branin_data_multi_objective(),\n", ")\n", "model_bridge_with_EHVI.gen(1)" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:57:59] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:57:59] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "GeneratorRun(1 arms, total weight 1.0)" + }, + "metadata": {}, + "execution_count": 14 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "changed-maintenance", "metadata": { - "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + "collapsed": false, + "executionStartTime": 1730905079324, + "executionStopTime": 1730905079651, + "id": "changed-maintenance", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + }, + "originalKey": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", + "outputsInitialized": true, + "requestMsgId": "87247bf1-04f2-4a8a-92f6-18174d70cbb7", + "serverExecutionDuration": 3.0141109600663 }, - "outputs": [], "source": [ "model_bridge_with_EHVI.model.botorch_acqf_class" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement" + }, + "metadata": {}, + "execution_count": 15 + } ] }, { "cell_type": "code", - "execution_count": null, - "id": "operating-shelf", "metadata": { - "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + "collapsed": false, + "executionStartTime": 1730905087115, + "executionStopTime": 1730905087362, + "id": "operating-shelf", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + }, + "originalKey": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", + "outputsInitialized": true, + "requestMsgId": "22fefbbf-68a9-4d5d-ade9-9df425995c3b", + "serverExecutionDuration": 3.2659249845892 }, - "outputs": [], "source": [ - "model_bridge_with_EHVI.model.surrogate.botorch_model_class" + "type(model_bridge_with_EHVI.model.surrogate.model)" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "botorch.models.gp_regression.SingleTaskGP" + }, + "metadata": {}, + "execution_count": 16 + } ] }, { "cell_type": "markdown", - "id": "fatal-butterfly", "metadata": { - "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + "id": "fatal-butterfly", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + }, + "originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb", + "outputsInitialized": false, + "showInput": false }, "source": [ "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " @@ -535,9 +971,16 @@ }, { "cell_type": "markdown", - "id": "hearing-interface", "metadata": { - "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + "id": "hearing-interface", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + }, + "originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 5. Utilizing `BoTorchModel` in generation strategies\n", @@ -549,12 +992,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "received-registration", "metadata": { - "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + "collapsed": false, + "executionStartTime": 1730905104120, + "executionStopTime": 1730905104377, + "id": "received-registration", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + }, + "originalKey": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", + "outputsInitialized": true, + "requestMsgId": "d0303c24-98bb-4c89-87cb-fa32ff498bd4", + "serverExecutionDuration": 2.4519630242139 }, - "outputs": [], "source": [ "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", @@ -576,19 +1028,30 @@ " # No limit on how many generator runs will be produced\n", " num_trials=-1,\n", " model_kwargs={ # Kwargs to pass to `BoTorchModel.__init__`\n", - " \"surrogate\": Surrogate(SingleTaskGP),\n", + " \"surrogate\": Surrogate(\n", + " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", + " ),\n", " \"botorch_acqf_class\": qLogNoisyExpectedImprovement,\n", " },\n", " ),\n", " ]\n", ")" - ] + ], + "execution_count": 17, + "outputs": [] }, { "cell_type": "markdown", - "id": "logical-windsor", "metadata": { - "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + "id": "logical-windsor", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + }, + "originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d", + "outputsInitialized": false, + "showInput": false }, "source": [ "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" @@ -596,24 +1059,58 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "viral-cheese", "metadata": { - "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + "collapsed": false, + "executionStartTime": 1730905105295, + "executionStopTime": 1730905105570, + "id": "viral-cheese", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + }, + "originalKey": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", + "outputsInitialized": true, + "requestMsgId": "e8aa4013-eb6f-4f50-a5f0-f963369495ed", + "serverExecutionDuration": 4.1721769375727 }, - "outputs": [], "source": [ "experiment = get_branin_experiment(minimize=True)\n", "\n", "assert len(experiment.trials) == 0\n", "experiment.search_space" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:58:25] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])" + }, + "metadata": {}, + "execution_count": 18 + } ] }, { "cell_type": "markdown", - "id": "incident-newspaper", "metadata": { - "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + "id": "incident-newspaper", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + }, + "originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 5a. Specifying `pending_observations`\n", @@ -624,12 +1121,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "casual-spread", "metadata": { - "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + "collapsed": false, + "executionStartTime": 1730905107391, + "executionStopTime": 1730905109351, + "id": "casual-spread", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + }, + "originalKey": "36ca75ec-b37c-498c-a487-5652cd3dc34b", + "outputsInitialized": true, + "requestMsgId": "36ca75ec-b37c-498c-a487-5652cd3dc34b", + "serverExecutionDuration": 1696.8911510194 }, - "outputs": [], "source": [ "for _ in range(10):\n", " # Produce a new generator run and attach it to experiment as a trial\n", @@ -648,13 +1154,58 @@ " trial.mark_completed()\n", "\n", " print(f\"Completed trial #{trial.index}, suggested by {generator_run._model_key}.\")" + ], + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #0, suggested by Sobol.\nCompleted trial #1, suggested by Sobol.\nCompleted trial #2, suggested by Sobol.\nCompleted trial #3, suggested by Sobol.\nCompleted trial #4, suggested by Sobol.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #5, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #6, suggested by BoTorch.\nCompleted trial #7, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #8, suggested by BoTorch.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Completed trial #9, suggested by BoTorch.\n" + ] + } ] }, { "cell_type": "markdown", - "id": "circular-vermont", "metadata": { - "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + "id": "circular-vermont", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + }, + "originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25", + "outputsInitialized": false, + "showInput": false }, "source": [ "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" @@ -662,21 +1213,200 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "significant-particular", "metadata": { - "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + "collapsed": false, + "executionStartTime": 1730905108832, + "executionStopTime": 1730905109399, + "id": "significant-particular", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + }, + "originalKey": "68d567be-27d6-4244-b22f-d6c53ed2d303", + "outputsInitialized": true, + "requestMsgId": "68d567be-27d6-4244-b22f-d6c53ed2d303", + "serverExecutionDuration": 32.219067099504 }, - "outputs": [], "source": [ "exp_to_df(experiment)" + ], + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[WARNING 11-06 06:58:29] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": " trial_index arm_name trial_status ... branin x1 x2\n0 0 0_0 COMPLETED ... 79.581199 1.380743 12.280850\n1 1 1_0 COMPLETED ... 17.366840 6.989676 1.438049\n2 2 2_0 COMPLETED ... 61.299075 6.097525 7.568626\n3 3 3_0 COMPLETED ... 71.268812 -3.293570 4.231312\n4 4 4_0 COMPLETED ... 3.831238 -2.268755 10.230113\n5 5 5_0 COMPLETED ... 4.246417 -3.354258 10.886093\n6 6 6_0 COMPLETED ... 6.712767 9.467421 0.000000\n7 7 7_0 COMPLETED ... 17.508300 -5.000000 15.000000\n8 8 8_0 COMPLETED ... 2.507635 10.000000 2.251628\n9 9 9_0 COMPLETED ... 1.318731 8.983844 2.177537\n\n[10 rows x 7 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
trial_indexarm_nametrial_statusgeneration_methodbraninx1x2
000_0COMPLETEDSobol79.5811991.38074312.280850
111_0COMPLETEDSobol17.3668406.9896761.438049
222_0COMPLETEDSobol61.2990756.0975257.568626
333_0COMPLETEDSobol71.268812-3.2935704.231312
444_0COMPLETEDSobol3.831238-2.26875510.230113
555_0COMPLETEDBoTorch4.246417-3.35425810.886093
666_0COMPLETEDBoTorch6.7127679.4674210.000000
777_0COMPLETEDBoTorch17.508300-5.00000015.000000
888_0COMPLETEDBoTorch2.50763510.0000002.251628
999_0COMPLETEDBoTorch1.3187318.9838442.177537
\n
", + "application/vnd.dataresource+json": { + "schema": { + "fields": [ + { + "name": "index", + "type": "integer" + }, + { + "name": "trial_index", + "type": "integer" + }, + { + "name": "arm_name", + "type": "string" + }, + { + "name": "trial_status", + "type": "string" + }, + { + "name": "generation_method", + "type": "string" + }, + { + "name": "branin", + "type": "number" + }, + { + "name": "x1", + "type": "number" + }, + { + "name": "x2", + "type": "number" + } + ], + "primaryKey": [ + "index" + ], + "pandas_version": "1.4.0" + }, + "data": [ + { + "index": 0, + "trial_index": 0, + "arm_name": "0_0", + "trial_status": "COMPLETED", + "generation_method": "Sobol", + "branin": 79.5811993025, + "x1": 1.3807432353, + "x2": 12.2808498144 + }, + { + "index": 1, + "trial_index": 1, + "arm_name": "1_0", + "trial_status": "COMPLETED", + "generation_method": "Sobol", + "branin": 17.3668397479, + "x1": 6.9896756578, + "x2": 1.4380489429 + }, + { + "index": 2, + "trial_index": 2, + "arm_name": "2_0", + "trial_status": "COMPLETED", + "generation_method": "Sobol", + "branin": 61.2990749448, + "x1": 6.097525456, + "x2": 7.5686264969 + }, + { + "index": 3, + "trial_index": 3, + "arm_name": "3_0", + "trial_status": "COMPLETED", + "generation_method": "Sobol", + "branin": 71.268812081, + "x1": -3.2935703546, + "x2": 4.231311623 + }, + { + "index": 4, + "trial_index": 4, + "arm_name": "4_0", + "trial_status": "COMPLETED", + "generation_method": "Sobol", + "branin": 3.8312383283, + "x1": -2.2687551333, + "x2": 10.2301133936 + }, + { + "index": 5, + "trial_index": 5, + "arm_name": "5_0", + "trial_status": "COMPLETED", + "generation_method": "BoTorch", + "branin": 4.2464169491, + "x1": -3.3542581623, + "x2": 10.8860926765 + }, + { + "index": 6, + "trial_index": 6, + "arm_name": "6_0", + "trial_status": "COMPLETED", + "generation_method": "BoTorch", + "branin": 6.7127667696, + "x1": 9.4674207228, + "x2": 0 + }, + { + "index": 7, + "trial_index": 7, + "arm_name": "7_0", + "trial_status": "COMPLETED", + "generation_method": "BoTorch", + "branin": 17.5082995158, + "x1": -5, + "x2": 15 + }, + { + "index": 8, + "trial_index": 8, + "arm_name": "8_0", + "trial_status": "COMPLETED", + "generation_method": "BoTorch", + "branin": 2.5076350936, + "x1": 10, + "x2": 2.2516281613 + }, + { + "index": 9, + "trial_index": 9, + "arm_name": "9_0", + "trial_status": "COMPLETED", + "generation_method": "BoTorch", + "branin": 1.3187305399, + "x1": 8.983844442, + "x2": 2.1775366828 + } + ] + } + }, + "metadata": {}, + "execution_count": 20 + } ] }, { "cell_type": "markdown", - "id": "obvious-transparency", "metadata": { - "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" + "id": "obvious-transparency", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "c25da720-6d3d-4f16-b878-24f2d2755783" + }, + "originalKey": "633c66af-a89f-4f03-a88b-866767d0a52f", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 6. Customizing a `Surrogate` or `Acquisition`\n", @@ -688,12 +1418,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "organizational-balance", "metadata": { - "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" + "collapsed": false, + "executionStartTime": 1730905111356, + "executionStopTime": 1730905111601, + "id": "organizational-balance", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "e7f8e413-f01e-4f9d-82c1-4912097637af" + }, + "originalKey": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", + "outputsInitialized": true, + "requestMsgId": "8fb45ee5-b75f-459e-afd2-5f7e7c7d4693", + "serverExecutionDuration": 2.4916339898482 }, - "outputs": [], "source": [ "from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform\n", "from botorch.acquisition.risk_measures import RiskMeasureMCObjective\n", @@ -711,13 +1450,22 @@ " risk_measure: Optional[RiskMeasureMCObjective] = None,\n", " ) -> Tuple[Optional[MCAcquisitionObjective], Optional[PosteriorTransform]]:\n", " ... # Produce the desired `MCAcquisitionObjective` and `PosteriorTransform` instead of the default" - ] + ], + "execution_count": 21, + "outputs": [] }, { "cell_type": "markdown", - "id": "theoretical-horizon", "metadata": { - "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" + "id": "theoretical-horizon", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "7299f0fc-e19e-4383-99de-ef7a9a987fe9" + }, + "originalKey": "0ec8606d-9d5b-4bcb-ad7e-f54839ad6f9b", + "outputsInitialized": false, + "showInput": false }, "source": [ "Then to use the new subclass in `BoTorchModel`, just specify `acquisition_class` argument along with `botorch_acqf_class` (to `BoTorchModel` directly or to `Models.BOTORCH_MODULAR`, which just passes the relevant arguments to `BoTorchModel` under the hood, as discussed in section 4):" @@ -725,12 +1473,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "approximate-rolling", "metadata": { - "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" + "collapsed": false, + "executionStartTime": 1730905113131, + "executionStopTime": 1730905113387, + "id": "approximate-rolling", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "07fe169a-78de-437e-9857-7c99cc48eedc" + }, + "originalKey": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", + "outputsInitialized": true, + "requestMsgId": "d2cbf675-77f6-4bbe-9eb6-42a6834ccaab", + "serverExecutionDuration": 12.22031598445 }, - "outputs": [], "source": [ "Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", @@ -738,13 +1495,38 @@ " acquisition_class=CustomObjectiveAcquisition,\n", " botorch_acqf_class=MyAcquisitionFunctionClass,\n", ")" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[INFO 11-06 06:58:33] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "TorchModelBridge(model=BoTorchModel)" + }, + "metadata": {}, + "execution_count": 22 + } ] }, { "cell_type": "markdown", - "id": "representative-implement", "metadata": { - "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" + "id": "representative-implement", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "608d5f0d-4528-4aa6-869d-db38fcbfb256" + }, + "originalKey": "cdcfb2bc-3016-4681-9fff-407f28321c3f", + "outputsInitialized": false, + "showInput": false }, "source": [ "To use a custom `Surrogate` subclass, pass the `surrogate` argument of that type:\n", @@ -759,9 +1541,16 @@ }, { "cell_type": "markdown", - "id": "framed-intermediate", "metadata": { - "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" + "id": "framed-intermediate", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "64f1289e-73c7-4cc5-96ee-5091286a8361" + }, + "originalKey": "ff03d674-f584-403f-ba65-f1bab921845b", + "outputsInitialized": false, + "showInput": false }, "source": [ "------" @@ -769,9 +1558,16 @@ }, { "cell_type": "markdown", - "id": "metropolitan-feedback", "metadata": { - "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" + "id": "metropolitan-feedback", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "d1e37569-dd0d-4561-b890-2f0097a345e0" + }, + "originalKey": "f71fcfa1-fc59-4bfb-84d6-b94ea5298bfa", + "outputsInitialized": false, + "showInput": false }, "source": [ "## Appendix 1: Methods available on `BoTorchModel`\n", @@ -787,14 +1583,22 @@ "* `update` updates surrogate model with training data and optionally reoptimizes model parameters via `Surrogate.update`,\n", "* `cross_validate` re-fits the surrogate model to subset of training data and makes predictions for test data,\n", "* `evaluate_acquisition_function` instantiates an acquisition function and evaluates it for a given point.\n", - "------\n" + "------\n", + "" ] }, { "cell_type": "markdown", - "id": "possible-transsexual", "metadata": { - "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" + "id": "possible-transsexual", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "b02f928c-57d9-4b2a-b4fe-c6d28d368b12" + }, + "originalKey": "91cedde4-8911-441f-af05-eb124581cbbc", + "outputsInitialized": false, + "showInput": false }, "source": [ "## Appendix 2: Default surrogate models and acquisition functions\n", @@ -812,9 +1616,16 @@ }, { "cell_type": "markdown", - "id": "continuous-strain", "metadata": { - "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" + "id": "continuous-strain", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "76ae9852-9d21-43d6-bf75-bb087a474dd6" + }, + "originalKey": "c8b0f933-8df6-479b-aa61-db75ca877624", + "outputsInitialized": false, + "showInput": false }, "source": [ "## Appendix 3: Handling storage errors that arise from objects that don't have serialization logic in A\n", @@ -824,9 +1635,16 @@ }, { "cell_type": "markdown", - "id": "broadband-voice", "metadata": { - "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" + "id": "broadband-voice", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "6487b68e-b808-4372-b6ba-ab02ce4826bc" + }, + "originalKey": "4d82f49a-3a8b-42f0-a4f5-5c079b793344", + "outputsInitialized": false, + "showInput": false }, "source": [ "The two options for handling this error are:\n", @@ -834,26 +1652,5 @@ "2. specifying serialization logic for a given object that needs to occur among the `Model` or `AcquisitionFunction` options. Tutorial for this is in the works, but in the meantime you can [post an issue on the Ax GitHub](https://github.com/facebook/Ax/issues) to get help with this." ] } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + ] }