diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 688813ac32c..c78fae3e78c 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -10,6 +10,7 @@ import torch from ax.core.observation import ObservationFeatures +from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.modelbridge.discrete import DiscreteModelBridge from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import ( @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None: SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( - saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP + saasbo.model.surrogate.model_configs[0].botorch_model_class, + SaasFullyBayesianSingleTaskGP, ) @mock_botorch_optimize @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: self.assertIsInstance(mtgp, TorchModelBridge) self.assertIsInstance(mtgp.model, BoTorchModel) self.assertEqual(mtgp.model.acquisition_class, Acquisition) - self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + is_moo = isinstance( + exp.optimization_config, MultiObjectiveOptimizationConfig + ) + if is_moo: + self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP) + models = mtgp.model.surrogate.model.models + else: + models = [mtgp.model.surrogate.model] - for model in mtgp.model.surrogate.model.models: + for model in models: self.assertIsInstance( model, SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index a5e8f84aedc..03804f6f7de 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -28,6 +28,7 @@ check_outcome_dataset_match, choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, + ModelConfig, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -79,6 +80,8 @@ class SurrogateSpec: allow_batched_models: bool = True + model_configs: list[ModelConfig] = field(default_factory=list) + metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict) outcomes: list[str] = field(default_factory=list) @@ -241,13 +244,19 @@ def fit( input_transform_options=spec.input_transform_options, outcome_transform_classes=spec.outcome_transform_classes, outcome_transform_options=spec.outcome_transform_options, + model_configs=spec.model_configs, + metric_to_model_configs=spec.metric_to_model_configs, allow_batched_models=spec.allow_batched_models, ) else: self._surrogate = Surrogate() # Fit the surrogate. 
- self.surrogate.model_options.update(additional_model_inputs) + for config in self.surrogate.model_configs: + config.model_options.update(additional_model_inputs) + for config_list in self.surrogate.metric_to_model_configs.values(): + for config in config_list: + config.model_options.update(additional_model_inputs) self.surrogate.fit( datasets=datasets, search_space_digest=search_space_digest, diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 0fab257ffca..80ecbd2a484 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -9,6 +9,7 @@ from __future__ import annotations import inspect +import warnings from collections import OrderedDict from collections.abc import Sequence from copy import deepcopy @@ -33,6 +34,7 @@ choose_model_class, convert_to_block_design, fit_botorch_model, + ModelConfig, subset_state_dict, use_model_list, ) @@ -50,11 +52,11 @@ _argparse_type_encoder, checked_cast, checked_cast_optional, + not_none, ) from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP -from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, @@ -68,6 +70,7 @@ from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from pyre_extensions import none_throws from torch import Tensor NOT_YET_FIT_MSG = ( @@ -268,6 +271,19 @@ def _set_formatted_inputs( formatted_model_inputs[input_name] = input_class(**input_options) +def _raise_deprecation_warning(*args: Any, **kwargs: Any) -> None: + for k, v in kwargs.items(): + if (v is not None and k != "mll_class") or ( + k == "mll_class" and v is not ExactMarginalLogLikelihood + ): + warnings.warn( + f"{k} is deprecated and will be removed in a future version. " + f"Please specify {k} via `model_configs`.", + DeprecationWarning, + stacklevel=3, + ) + + class Surrogate(Base): """ **All classes in 'botorch_modular' directory are under @@ -282,15 +298,20 @@ class Surrogate(Base): BoTorch model. If None is provided a model class will be selected (either one for all outcomes or a ModelList with separate models for each outcome) will be selected automatically based off the datasets at `construct` time. + This argument is deprecated in favor of model_configs. model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. Note that the corresponding attribute will later be updated to include any additional kwargs passed into ``BoTorchModel.fit``. + This argument is deprecated in favor of model_configs. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. - mll_options: Dictionary of options / kwargs for the MLL. + This argument is deprecated in favor of model_configs. + mll_options: Dictionary of options / kwargs for the MLL. This argument is + deprecated in favor of model_configs. outcome_transform_classes: List of BoTorch outcome transforms classes. Passed down to the BoTorch ``Model``. Multiple outcome transforms can be chained - together using ``ChainedOutcomeTransform``. + together using ``ChainedOutcomeTransform``. This argument is deprecated in + favor of model_configs. outcome_transform_options: Outcome transform classes kwargs. 
The keys are class string names and the values are dictionaries of outcome transform kwargs. For example, @@ -299,10 +320,12 @@ class string names and the values are dictionaries of outcome transform outcome_transform_options = { "Standardize": {"m": 1}, ` - For more options see `botorch/models/transforms/outcome.py`. + For more options see `botorch/models/transforms/outcome.py`. This argument + is deprecated in favor of model_configs. input_transform_classes: List of BoTorch input transforms classes. Passed down to the BoTorch ``Model``. Multiple input transforms will be chained together using ``ChainedInputTransform``. + This argument is deprecated in favor of model_configs. input_transform_options: Input transform classes kwargs. The keys are class string names and the values are dictionaries of input transform kwargs. For example, @@ -314,13 +337,22 @@ class string names and the values are dictionaries of input transform } ` For more input options see `botorch/models/transforms/input.py`. + This argument is deprecated in favor of model_configs. covar_module_class: Covariance module class. This gets initialized after parsing the ``covar_module_options`` in ``covar_module_argparse``, and gets passed to the model constructor as ``covar_module``. - covar_module_options: Covariance module kwargs. + This argument is deprecated in favor of model_configs. + covar_module_options: Covariance module kwargs. This argument is deprecated + in favor of model_configs. likelihood: ``Likelihood`` class. This gets initialized with ``likelihood_options`` and gets passed to the model constructor. - likelihood_options: Likelihood options. + This argument is deprecated in favor of model_configs. + likelihood_options: Likelihood options. This argument is deprecated in favor + of model_configs. + model_configs: List of model configs. Each model config is a specification of + a model. These should be used in favor of the above deprecated arguments. + metric_to_model_configs: Dictionary mapping metric names to a list of model + configs for that metric. allow_batched_models: Set to true to fit the models in a batch if supported. Set to false to fit individual models to each metric in a loop. """ @@ -339,22 +371,60 @@ def __init__( covar_module_options: dict[str, Any] | None = None, likelihood_class: type[Likelihood] | None = None, likelihood_options: dict[str, Any] | None = None, + model_configs: list[ModelConfig] | None = None, + metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> None: - self.botorch_model_class = botorch_model_class - # Copying model options to avoid mutating the original dict. - # We later update it with any additional kwargs passed into `BoTorchModel.fit`. 
- self.model_options: dict[str, Any] = (model_options or {}).copy() - self.mll_class = mll_class - self.mll_options: dict[str, Any] = mll_options or {} - self.outcome_transform_classes = outcome_transform_classes - self.outcome_transform_options: dict[str, Any] = outcome_transform_options or {} - self.input_transform_classes = input_transform_classes - self.input_transform_options: dict[str, Any] = input_transform_options or {} - self.covar_module_class = covar_module_class - self.covar_module_options: dict[str, Any] = covar_module_options or {} - self.likelihood_class = likelihood_class - self.likelihood_options: dict[str, Any] = likelihood_options or {} + _raise_deprecation_warning( + botorch_model_class=botorch_model_class, + model_options=model_options, + mll_class=mll_class, + mll_options=mll_options, + outcome_transform_classes=outcome_transform_classes, + outcome_transform_options=outcome_transform_options, + input_transform_classes=input_transform_classes, + input_transform_options=input_transform_options, + covar_module_class=covar_module_class, + covar_module_options=covar_module_options, + likelihood_class=likelihood_class, + likelihood_options=likelihood_options, + ) + model_configs = model_configs or [] + if len(model_configs) == 0: + model_config_kwargs = { + "botorch_model_class": botorch_model_class, + "model_options": (model_options or {}).copy(), + "mll_class": mll_class, + "mll_options": mll_options, + "outcome_transform_classes": outcome_transform_classes, + "outcome_transform_options": outcome_transform_options, + "input_transform_classes": input_transform_classes, + "input_transform_options": input_transform_options, + "covar_module_class": covar_module_class, + "covar_module_options": covar_module_options, + "likelihood_class": likelihood_class, + "likelihood_options": likelihood_options, + } + model_config_kwargs = { + k: v for k, v in model_config_kwargs.items() if v is not None + } + # pyre-fixme [6]: Incompatible parameter type [6]: In call + # `ModelConfig.__init__`, for 1st positional argument, expected + # `Dict[str, typing.Any]` but got `Union[Dict[str, typing.Any], + # Dict[str, Dict[str, typing.Any]], Sequence[Type[InputTransform]], + # Sequence[Type[OutcomeTransform]], Type[Union[MarginalLogLikelihood, + # Model]], Type[Likelihood], Type[Kernel]]`. + model_configs = [ModelConfig(**model_config_kwargs)] + self.model_configs: list[ModelConfig] = model_configs + self.metric_to_model_configs: dict[str, list[ModelConfig]] = ( + metric_to_model_configs or {} + ) + if len(self.model_configs) > 1 or any( + len(model_config) > 1 + for model_config in self.metric_to_model_configs.values() + ): + raise NotImplementedError("Only one model config per metric is supported.") + self.allow_batched_models = allow_batched_models # Store the last dataset used to fit the model for a given metric(s). # If the new dataset is identical, we will skip model fitting for that metric. 
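For readers following the constructor change above: the new `model_configs` / `metric_to_model_configs` arguments replace the flat `Surrogate` kwargs. Below is a minimal sketch of how the two construction styles are expected to compose under this change; it is illustrative only (the `branin` metric name and the kernel/MLL choices are assumptions, not defaults from this diff).

    from ax.models.torch.botorch_modular.surrogate import Surrogate
    from ax.models.torch.botorch_modular.utils import ModelConfig
    from botorch.models.gp_regression import SingleTaskGP
    from gpytorch.kernels import RBFKernel
    from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

    # Deprecated flat-kwarg style: still accepted, but each non-default kwarg now
    # emits a DeprecationWarning and is folded into a single ModelConfig internally.
    legacy_surrogate = Surrogate(
        botorch_model_class=SingleTaskGP,
        covar_module_class=RBFKernel,
    )

    # Equivalent specification via the new `model_configs` argument.
    surrogate = Surrogate(
        model_configs=[
            ModelConfig(
                botorch_model_class=SingleTaskGP,
                covar_module_class=RBFKernel,
                mll_class=ExactMarginalLogLikelihood,
            )
        ]
    )

    # Per-metric overrides: metrics without an entry fall back to model_configs[0].
    # ("branin" is a hypothetical metric name used only for illustration.)
    surrogate_per_metric = Surrogate(
        model_configs=[ModelConfig()],
        metric_to_model_configs={
            "branin": [ModelConfig(botorch_model_class=SingleTaskGP)],
        },
    )

Note that, per the constructor above, only a single `ModelConfig` per metric is currently supported; longer lists raise `NotImplementedError` until automated model selection is added (see the TODO in `Surrogate.fit`).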
@@ -377,10 +447,8 @@ def __init__( def __repr__(self) -> str: return ( f"<{self.__class__.__name__}" - f" botorch_model_class={self.botorch_model_class} " - f"mll_class={self.mll_class} " - f"outcome_transform_classes={self.outcome_transform_classes} " - f"input_transform_classes={self.input_transform_classes} " + f" model_configs={self.model_configs}," + f" metric_to_model_configs={self.metric_to_model_configs}>" ) @property @@ -404,9 +472,7 @@ def Xs(self) -> list[Tensor]: training_data = self.training_data Xs = [] for dataset in training_data: - if self.botorch_model_class == PairwiseGP and isinstance( - dataset, RankingDataset - ): + if isinstance(dataset, RankingDataset): # directly accessing the d-dim X tensor values # instead of the augmented 2*d-dim dataset.X from RankingDataset Xi = checked_cast(SliceContainer, dataset._X).values @@ -431,7 +497,8 @@ def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - botorch_model_class: type[Model], + model_config: ModelConfig, + default_botorch_model_class: type[Model], state_dict: OrderedDict[str, Tensor] | None, refit: bool, ) -> Model: @@ -446,19 +513,24 @@ def _construct_model( multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. - botorch_model_class: ``Model`` class to be used as the underlying - BoTorch model. + model_config: The model_config. + default_botorch_model_class: The default ``Model`` class to be used as the + underlying BoTorch model, if the model_config does not specify one. state_dict: Optional state dict to load. This should be subsetted for the current submodel being constructed. refit: Whether to re-optimize model parameters. """ outcome_names = tuple(dataset.outcome_names) + botorch_model_class = ( + model_config.botorch_model_class or default_botorch_model_class + ) if self._should_reuse_last_model( dataset=dataset, botorch_model_class=botorch_model_class ): return self._submodels[outcome_names] formatted_model_inputs = submodel_input_constructor( botorch_model_class, # Do not pass as kwarg since this is used to dispatch. + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=self, @@ -469,7 +541,9 @@ def _construct_model( model.load_state_dict(state_dict) if state_dict is None or refit: fit_botorch_model( - model=model, mll_class=self.mll_class, mll_options=self.mll_options + model=model, + mll_class=model_config.mll_class, + mll_options=model_config.mll_options, ) self._submodels[outcome_names] = model self._last_datasets[outcome_names] = dataset @@ -539,14 +613,21 @@ def fit( # To determine whether to use ModelList under the hood, we need to check for # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. 
- botorch_model_class = self.botorch_model_class or choose_model_class( - datasets=datasets, search_space_digest=search_space_digest - ) - + if ( + len(self.model_configs) == 1 + and self.model_configs[0].botorch_model_class is None + ): + default_botorch_model_class = choose_model_class( + datasets=datasets, search_space_digest=search_space_digest + ) + else: + default_botorch_model_class = self.model_configs[0].botorch_model_class should_use_model_list = use_model_list( datasets=datasets, - botorch_model_class=botorch_model_class, + botorch_model_class=not_none(default_botorch_model_class), + model_configs=self.model_configs, allow_batched_models=self.allow_batched_models, + metric_to_model_configs=self.metric_to_model_configs, ) if not should_use_model_list and len(datasets) > 1: @@ -564,10 +645,30 @@ def fit( ) else: submodel_state_dict = state_dict + model_config = None + if len(self.metric_to_model_configs) > 0: + # if metric_to_model_configs is not empty, then + # we are using a model list and each dataset + # should have only one outcome. + if len(dataset.outcome_names) > 1: + raise ValueError( + "Each dataset should have only one outcome when " + "metric_to_model_configs is specified." + ) + model_config_list = self.metric_to_model_configs.get( + dataset.outcome_names[0] + ) + + # TODO: add support for automated model selection + if model_config_list is not None: + model_config = model_config_list[0] + if model_config is None: + model_config = self.model_configs[0] model = self._construct_model( dataset=dataset, search_space_digest=search_space_digest, - botorch_model_class=botorch_model_class, + model_config=model_config, + default_botorch_model_class=not_none(default_botorch_model_class), state_dict=submodel_state_dict, refit=refit, ) @@ -728,23 +829,12 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]: as kwargs on reinstantiation. """ return { - "botorch_model_class": self.botorch_model_class, - "model_options": self.model_options, - "mll_class": self.mll_class, - "mll_options": self.mll_options, - "outcome_transform_classes": self.outcome_transform_classes, - "outcome_transform_options": self.outcome_transform_options, - "input_transform_classes": self.input_transform_classes, - "input_transform_options": self.input_transform_options, - "covar_module_class": self.covar_module_class, - "covar_module_options": self.covar_module_options, - "likelihood_class": self.likelihood_class, - "likelihood_options": self.likelihood_options, - "allow_batched_models": self.allow_batched_models, + "model_configs": self.model_configs, + "metric_to_model_configs": self.metric_to_model_configs, } def _extract_construct_input_transform_args( - self, search_space_digest: SearchSpaceDigest + self, model_config: ModelConfig, search_space_digest: SearchSpaceDigest ) -> tuple[Sequence[type[InputTransform]] | None, dict[str, dict[str, Any]]]: """ Extracts input transform classes and input transform options that will @@ -777,19 +867,19 @@ def _extract_construct_input_transform_args( InputPerturbation ] - if self.input_transform_classes is not None: + if model_config.input_transform_classes is not None: # TODO: Support mixing with user supplied transforms. raise NotImplementedError( "User supplied input transforms are not supported " "in robust optimization." 
) else: - submodel_input_transform_classes = self.input_transform_classes - submodel_input_transform_options = self.input_transform_options + submodel_input_transform_classes = model_config.input_transform_classes + submodel_input_transform_options = model_config.input_transform_options return ( submodel_input_transform_classes, - submodel_input_transform_options, + none_throws(submodel_input_transform_options), ) @property @@ -811,6 +901,7 @@ def outcomes(self, value: list[str]) -> None: @submodel_input_constructor.register(Model) def _submodel_input_constructor_base( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -819,6 +910,7 @@ def _submodel_input_constructor_base( Args: botorch_model_class: The BoTorch model class to instantiate. + model_config: The model config. dataset: The training data for the model. search_space_digest: Search space digest used to set up model arguments. surrogate: A reference to the surrogate that created the model. @@ -836,12 +928,12 @@ def _submodel_input_constructor_base( input_transform_classes, input_transform_options, ) = surrogate._extract_construct_input_transform_args( - search_space_digest=search_space_digest + model_config=model_config, search_space_digest=search_space_digest ) formatted_model_inputs = botorch_model_class.construct_inputs( training_data=dataset, - **surrogate.model_options, + **model_config.model_options, **model_kwargs_from_ss, ) @@ -851,14 +943,18 @@ def _submodel_input_constructor_base( inputs=[ ( "covar_module", - surrogate.covar_module_class, - surrogate.covar_module_options, + model_config.covar_module_class, + model_config.covar_module_options, + ), + ( + "likelihood", + model_config.likelihood_class, + model_config.likelihood_options, ), - ("likelihood", surrogate.likelihood_class, surrogate.likelihood_options), ( "outcome_transform", - surrogate.outcome_transform_classes, - surrogate.outcome_transform_options, + model_config.outcome_transform_classes, + model_config.outcome_transform_options, ), ( "input_transform", @@ -880,6 +976,7 @@ def _submodel_input_constructor_base( @submodel_input_constructor.register(MultiTaskGP) def _submodel_input_constructor_mtgp( botorch_model_class: type[Model], + model_config: ModelConfig, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, surrogate: Surrogate, @@ -888,6 +985,7 @@ def _submodel_input_constructor_mtgp( raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, + model_config=model_config, dataset=dataset, search_space_digest=search_space_digest, surrogate=surrogate, diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 7b8be177010..e78814f5675 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -9,6 +9,7 @@ import warnings from collections import OrderedDict from collections.abc import Sequence +from dataclasses import dataclass, field from logging import Logger from typing import Any @@ -34,8 +35,13 @@ from botorch.models.model import Model, ModelList from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset from 
botorch.utils.transforms import is_fully_bayesian +from gpytorch.kernels.kernel import Kernel +from gpytorch.likelihoods import Likelihood +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor @@ -43,20 +49,124 @@ logger: Logger = get_logger(__name__) +@dataclass +class ModelConfig: + """Configuration for the BoTorch Model used in Surrogate. + + Args: + botorch_model_class: ``Model`` class to be used as the underlying + BoTorch model. If None is provided, a model class (either one for all + outcomes or a ModelList with separate models for each outcome) will be + selected automatically based off the datasets at `construct` time. + model_options: Dictionary of options / kwargs for the BoTorch + ``Model`` constructed during ``Surrogate.fit``. + Note that the corresponding attribute will later be updated to include any + additional kwargs passed into ``BoTorchModel.fit``. + mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. + mll_options: Dictionary of options / kwargs for the MLL. + outcome_transform_classes: List of BoTorch outcome transform classes. Passed + down to the BoTorch ``Model``. Multiple outcome transforms can be chained + together using ``ChainedOutcomeTransform``. + outcome_transform_options: Outcome transform classes kwargs. The keys are + class string names and the values are dictionaries of outcome transform + kwargs. For example, + ` + outcome_transform_classes = [Standardize] + outcome_transform_options = { + "Standardize": {"m": 1}, + ` + For more options see `botorch/models/transforms/outcome.py`. + input_transform_classes: List of BoTorch input transform classes. + Passed down to the BoTorch ``Model``. Multiple input transforms + will be chained together using ``ChainedInputTransform``. + input_transform_options: Input transform classes kwargs. The keys are + class string names and the values are dictionaries of input transform + kwargs. For example, + ` + input_transform_classes = [Normalize, Round] + input_transform_options = { + "Normalize": {"d": 3}, + "Round": {"integer_indices": [0], "categorical_features": {1: 2}}, + } + ` + For more input options see `botorch/models/transforms/input.py`. + covar_module_class: Covariance module class. This gets initialized after + parsing the ``covar_module_options`` in ``covar_module_argparse``, + and gets passed to the model constructor as ``covar_module``. + covar_module_options: Covariance module kwargs. + likelihood_class: ``Likelihood`` class. This gets initialized with + ``likelihood_options`` and gets passed to the model constructor. + likelihood_options: Likelihood options.
+ """ + + botorch_model_class: type[Model] | None = None + model_options: dict[str, Any] = field(default_factory=dict) + mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood + mll_options: dict[str, Any] = field(default_factory=dict) + input_transform_classes: list[type[InputTransform]] | None = None + input_transform_options: dict[str, dict[str, Any]] | None = field( + default_factory=dict + ) + outcome_transform_classes: list[type[OutcomeTransform]] | None = None + outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict) + covar_module_class: type[Kernel] | None = None + covar_module_options: dict[str, Any] = field(default_factory=dict) + likelihood_class: type[Likelihood] | None = None + likelihood_options: dict[str, Any] = field(default_factory=dict) + + def use_model_list( datasets: Sequence[SupervisedDataset], botorch_model_class: type[Model], + model_configs: list[ModelConfig] | None = None, + metric_to_model_configs: dict[str, list[ModelConfig]] | None = None, allow_batched_models: bool = True, ) -> bool: - if issubclass(botorch_model_class, MultiTaskGP): - # We currently always wrap multi-task models into `ModelListGP`. + model_configs = model_configs or [] + metric_to_model_configs = metric_to_model_configs or {} + if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1: + # There is only one outcome, so we can use a single model. + return False + elif ( + len(model_configs) > 1 + or len(metric_to_model_configs) > 0 + or any(len(model_config) for model_config in metric_to_model_configs.values()) + ): + # There are multiple outcomes and the outcomes might be modeled with + # different models. return True - elif issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): + # Otherwise, the same model class is used for all outcomes. + # Determine what the model class is. + if len(model_configs) > 0: + botorch_model_class = ( + model_configs[0].botorch_model_class or botorch_model_class + ) + if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): # SAAS models do not support multiple outcomes. # Use model list if there are multiple outcomes. return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 + elif issubclass(botorch_model_class, MultiTaskGP): + # We wrap multi-task models into `ModelListGP` when there are + # multiple outcomes. + return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 elif len(datasets) == 1: - # Just one outcome, can use single model. + # This method is called before multiple datasets are merged into + # one if using a batched model. If there is one dataset here, + # there should be a reason that a single model should be used: + # e.g. a contextual model, where we want to jointly model the metric in + each context (and context-level metrics are different outcomes).
return False elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all( torch.equal(datasets[0].X, ds.X) for ds in datasets[1:] diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index c8cf1ddba7a..389b3072175 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -27,6 +27,7 @@ from ax.models.torch.botorch_modular.utils import ( choose_model_class, construct_acquisition_and_optimizer_options, + ModelConfig, ) from ax.models.torch.utils import _filter_X_observed from ax.models.torch_base import TorchOptConfig @@ -327,7 +328,21 @@ def test_fit(self, mock_fit: Mock) -> None: mock_fit.assert_called_with( dataset=self.block_design_training_data[0], search_space_digest=self.mf_search_space_digest, - botorch_model_class=SingleTaskMultiFidelityGP, + model_config=ModelConfig( + botorch_model_class=None, + model_options={}, + mll_class=ExactMarginalLogLikelihood, + mll_options={}, + input_transform_classes=None, + input_transform_options={}, + outcome_transform_classes=None, + outcome_transform_options={}, + covar_module_class=None, + covar_module_options={}, + likelihood_class=None, + likelihood_options={}, + ), + default_botorch_model_class=SingleTaskMultiFidelityGP, state_dict=None, refit=True, ) @@ -727,6 +742,8 @@ def test_surrogate_model_options_propagation( input_transform_options=None, outcome_transform_classes=None, outcome_transform_options=None, + model_configs=[], + metric_to_model_configs={}, allow_batched_models=True, ) @@ -755,6 +772,8 @@ def test_surrogate_options_propagation( input_transform_options=None, outcome_transform_classes=None, outcome_transform_options=None, + model_configs=[], + metric_to_model_configs={}, allow_batched_models=False, ) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 3aa5a733214..8392dd7a6eb 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -18,8 +18,13 @@ from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition +from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate -from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model +from ax.models.torch.botorch_modular.utils import ( + choose_model_class, + fit_botorch_model, + ModelConfig, +) from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast, not_none @@ -37,7 +42,7 @@ from botorch.models.transforms.outcome import Standardize from botorch.utils.datasets import SupervisedDataset from gpytorch.constraints import GreaterThan, Interval -from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import Kernel, LinearKernel, MaternKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood from pyre_extensions import assert_is_instance @@ -216,8 +221,10 @@ def _get_surrogate( def test_init(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) - 
self.assertEqual(surrogate.botorch_model_class, botorch_model_class) - self.assertEqual(surrogate.mll_class, self.mll_class) + self.assertEqual( + surrogate.model_configs[0].botorch_model_class, botorch_model_class + ) + self.assertEqual(surrogate.model_configs[0].mll_class, self.mll_class) self.assertTrue(surrogate.allow_batched_models) # True by default def test_clone_reset(self) -> None: @@ -432,7 +439,8 @@ def test_construct_model(self) -> None: Surrogate()._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=Model, + model_config=ModelConfig(), + default_botorch_model_class=Model, state_dict=None, refit=True, ) @@ -446,7 +454,8 @@ def test_construct_model(self) -> None: model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -471,7 +480,8 @@ def test_construct_model(self) -> None: new_model = surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=botorch_model_class, + model_config=surrogate.model_configs[0], + default_botorch_model_class=botorch_model_class, state_dict=None, refit=True, ) @@ -485,7 +495,8 @@ def test_construct_model(self) -> None: surrogate._construct_model( dataset=self.training_data[0], search_space_digest=self.search_space_digest, - botorch_model_class=SingleTaskGPWithDifferentConstructor, + model_config=ModelConfig(), + default_botorch_model_class=SingleTaskGPWithDifferentConstructor, state_dict=None, refit=True, ) @@ -497,14 +508,18 @@ def test_construct_model(self) -> None: ) @mock_botorch_optimize - def test_construct_custom_model(self) -> None: + def test_construct_custom_model(self, use_model_config: bool = False) -> None: # Test error for unsupported covar_module and likelihood. 
- surrogate = Surrogate( - botorch_model_class=SingleTaskGPWithDifferentConstructor, - mll_class=self.mll_class, - covar_module_class=RBFKernel, - likelihood_class=FixedNoiseGaussianLikelihood, - ) + model_config_kwargs: dict[str, Any] = { + "botorch_model_class": SingleTaskGPWithDifferentConstructor, + "mll_class": self.mll_class, + "covar_module_class": RBFKernel, + "likelihood_class": FixedNoiseGaussianLikelihood, + } + if use_model_config: + surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + else: + surrogate = Surrogate(**model_config_kwargs) with self.assertRaisesRegex(UserInputError, "does not support"): surrogate.fit( self.training_data, @@ -512,14 +527,18 @@ def test_construct_custom_model(self) -> None: ) # Pass custom options to a SingleTaskGP and make sure they are used noise_constraint = Interval(1e-6, 1e-1) - surrogate = Surrogate( - botorch_model_class=SingleTaskGP, - mll_class=LeaveOneOutPseudoLikelihood, - covar_module_class=RBFKernel, - covar_module_options={"ard_num_dims": 3}, - likelihood_class=GaussianLikelihood, - likelihood_options={"noise_constraint": noise_constraint}, - ) + model_config_kwargs = { + "botorch_model_class": SingleTaskGP, + "mll_class": LeaveOneOutPseudoLikelihood, + "covar_module_class": RBFKernel, + "covar_module_options": {"ard_num_dims": 3}, + "likelihood_class": GaussianLikelihood, + "likelihood_options": {"noise_constraint": noise_constraint}, + } + if use_model_config: + surrogate = Surrogate(model_configs=[ModelConfig(**model_config_kwargs)]) + else: + surrogate = Surrogate(**model_config_kwargs) surrogate.fit( self.training_data, search_space_digest=self.search_space_digest, @@ -532,10 +551,65 @@ def test_construct_custom_model(self) -> None: model.likelihood.noise_covar.raw_noise_constraint.__dict__, noise_constraint.__dict__, ) - self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) + self.assertEqual( + surrogate.model_configs[0].mll_class, LeaveOneOutPseudoLikelihood + ) self.assertEqual(type(model.covar_module), RBFKernel) self.assertEqual(model.covar_module.ard_num_dims, 3) + def test_construct_custom_model_with_config(self) -> None: + self.test_construct_custom_model(use_model_config=True) + + def test_construct_model_with_metric_to_model_configs(self) -> None: + surrogate = Surrogate( + metric_to_model_configs={ + "metric": [ModelConfig()], + "metric2": [ModelConfig(covar_module_class=ScaleMaternKernel)], + }, + model_configs=[ModelConfig(covar_module_class=LinearKernel)], + ) + training_data = self.training_data + [ + SupervisedDataset( + X=self.Xs[0], + # Note: using 1d Y does not match the 2d TorchOptConfig + Y=self.Ys[0], + feature_names=self.feature_names, + outcome_names=[f"metric{i}"], + ) + for i in range(2, 5) + ] + surrogate.fit( + datasets=training_data, search_space_digest=self.search_space_digest + ) + # test model follows metric_to_model_configs for + # first two metrics + self.assertIsInstance(surrogate.model, ModelListGP) + submodels = surrogate.model.models + self.assertEqual(len(submodels), 4) + for m in submodels: + self.assertIsInstance(m, SingleTaskGP) + self.assertIsInstance(surrogate.model.models[1].covar_module, ScaleKernel) + self.assertIsInstance( + surrogate.model.models[1].covar_module.base_kernel, MaternKernel + ) + self.assertIsInstance(surrogate.model.models[0].covar_module, RBFKernel) + # test model use model_configs for the third metric + self.assertIsInstance(surrogate.model.models[2].covar_module, LinearKernel) + + def test_multiple_model_configs_error(self) -> 
None: + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + Surrogate( + model_configs=[ModelConfig(), ModelConfig()], + ) + with self.assertRaisesRegex( + NotImplementedError, "Only one model config per metric is supported." + ): + Surrogate( + metric_to_model_configs={"metric": [ModelConfig(), ModelConfig()]}, + ) + @mock_botorch_optimize @patch(f"{SURROGATE_PATH}.predict_from_model") def test_predict(self, mock_predict: Mock) -> None: @@ -647,7 +721,8 @@ def test_serialize_attributes_as_kwargs(self) -> None: for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class) expected = { - k: v for k, v in surrogate.__dict__.items() if not k.startswith("_") + "model_configs": surrogate.model_configs, + "metric_to_model_configs": surrogate.metric_to_model_configs, } self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) @@ -677,7 +752,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.training_data, @@ -830,11 +905,12 @@ def setUp(self) -> None: ) def test_init(self) -> None: + model_config = self.surrogate.model_configs[0] self.assertEqual( - [self.surrogate.botorch_model_class] * 2, + [model_config.botorch_model_class] * 2, [*self.botorch_submodel_class_per_outcome.values()], ) - self.assertEqual(self.surrogate.mll_class, self.mll_class) + self.assertEqual(model_config.mll_class, self.mll_class) with self.assertRaisesRegex( ValueError, "BoTorch `Model` has not yet been constructed" ): @@ -849,7 +925,7 @@ def test_init(self) -> None: def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock ) -> None: - self.surrogate.model_options.update({"output_tasks": [2]}) + self.surrogate.model_configs[0].model_options.update({"output_tasks": [2]}) for fixed_noise in (False, True): mock_fit.reset_mock() mock_MTGP_construct_inputs.reset_mock() @@ -940,10 +1016,14 @@ def test_fit( self.assertIsNone(surrogate._model) # Should instantiate mll and `fit_gpytorch_mll` when `state_dict` # is `None`. - # pyre-ignore[6]: Incompatible parameter type: In call - # `issubclass`, for 1st positional argument, expected - # `Type[typing.Any]` but got `Optional[Type[Model]]`. - is_mtgp = issubclass(surrogate.botorch_model_class, MultiTaskGP) + + is_mtgp = issubclass( + # pyre-ignore[6]: Incompatible parameter type: In call + # `issubclass`, for 1st positional argument, expected + # `Type[typing.Any]` but got `Optional[Type[Model]]`. + surrogate.model_configs[0].botorch_model_class, + MultiTaskGP, + ) search_space_digest = ( self.multi_task_search_space_digest if is_mtgp @@ -1097,25 +1177,6 @@ def test_with_botorch_transforms(self) -> None: ) ) - def test_serialize_attributes_as_kwargs(self) -> None: - # TODO[mpolson64] Reimplement this when serialization has been sorted out - pass - # expected = self.surrogate.__dict__ - # # The two attributes below don't need to be saved as part of state, - # # so we remove them from the expected dict. 
- # for attr_name in ( - # "botorch_model_class", - # "model_options", - # "covar_module_class", - # "covar_module_options", - # "likelihood_class", - # "likelihood_options", - # "outcome_transform", - # "input_transform", - # ): - # expected.pop(attr_name) - # self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected) - @mock_botorch_optimize def test_construct_custom_model(self) -> None: noise_constraint = Interval(1e-4, 10.0) @@ -1144,7 +1205,9 @@ def test_construct_custom_model(self) -> None: ) models = checked_cast(ModelListGP, surrogate._model).models self.assertEqual(len(models), 2) - self.assertEqual(surrogate.mll_class, ExactMarginalLogLikelihood) + self.assertEqual( + surrogate.model_configs[0].mll_class, ExactMarginalLogLikelihood + ) # Make sure we properly copied the transforms. self.assertNotEqual( id(models[0].input_transform), id(models[1].input_transform) @@ -1197,7 +1260,7 @@ def test_w_robust_digest(self) -> None: environmental_variables=[], multiplicative=False, ) - surrogate.input_transform_classes = [Normalize] + surrogate.model_configs[0].input_transform_classes = [Normalize] with self.assertRaisesRegex(NotImplementedError, "input transforms"): surrogate.fit( datasets=self.supervised_training_data, diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 4ab0f4677c4..fa34a47d9c6 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -305,7 +305,7 @@ def test_use_model_list(self) -> None: botorch_model_class=SingleTaskGP, ) ) - self.assertTrue( + self.assertFalse( use_model_list( datasets=self.supervised_datasets, botorch_model_class=MultiTaskGP ) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 965ad491397..b2032c51ce6 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -46,6 +46,7 @@ ) from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( batch_trial_from_json, botorch_component_from_json, @@ -228,7 +229,7 @@ def object_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - elif _class in (SurrogateSpec, Surrogate): + elif _class in (SurrogateSpec, Surrogate, ModelConfig): if "input_transform" in object_json: ( input_transform_classes_json, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index a798c30b1fe..02498d03669 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -98,6 +98,7 @@ from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig from ax.runners.synthetic import SyntheticRunner from ax.service.utils.scheduler_options import SchedulerOptions, TrialType @@ -330,6 +331,7 @@ "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, "Models": Models, "ModelRegistryBase": ModelRegistryBase, + "ModelConfig": ModelConfig, "ModelSpec": ModelSpec, "MultiObjective": MultiObjective, "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index 
275530f2f3c..3d245b57d3d 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -2,12 +2,55 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - "id": "about-preview", + "execution_count": 2, "metadata": { - "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + "collapsed": false, + "executionStartTime": 1730390425552, + "executionStopTime": 1730390441803, + "id": "about-preview", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "cca773d8-5e94-4b5a-ae54-22295be8936a" + }, + "originalKey": "a473b658-203c-404c-bdfc-e8fcd622ca0e", + "outputsInitialized": true, + "requestMsgId": "a473b658-203c-404c-bdfc-e8fcd622ca0e", + "serverExecutionDuration": 5929.5586600201 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I1031 090036.925 _utils_internal.py:321] NCCL_DEBUG env var is set to None\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I1031 090036.926 _utils_internal.py:339] NCCL_DEBUG is forced to WARN from None\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;5;200mRemapping module ax.models.torch.botorch_modular.model from /mnt/xarfuse/uid-28351/02d88658-seed-nspid4026533191_cgpid3387601-ns-4026533188/ax/models/torch/botorch_modular/model.py to /data/sandcastle/boxes/fbsource/fbcode/ax/models/torch/botorch_modular/model.py\u001b[0;0m\n", + "\u001b[38;5;200mRemapping module ax.models.torch.botorch_modular.surrogate from /mnt/xarfuse/uid-28351/02d88658-seed-nspid4026533191_cgpid3387601-ns-4026533188/ax/models/torch/botorch_modular/surrogate.py to /data/sandcastle/boxes/fbsource/fbcode/ax/models/torch/botorch_modular/surrogate.py\u001b[0;0m\n", + "\u001b[38;5;200mRemapping module ax.models.torch.botorch_modular.utils from /mnt/xarfuse/uid-28351/02d88658-seed-nspid4026533191_cgpid3387601-ns-4026533188/ax/models/torch/botorch_modular/utils.py to /data/sandcastle/boxes/fbsource/fbcode/ax/models/torch/botorch_modular/utils.py\u001b[0;0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;5;200mRemapping module ax.storage.json_store.decoder from /mnt/xarfuse/uid-28351/02d88658-seed-nspid4026533191_cgpid3387601-ns-4026533188/ax/storage/json_store/decoder.py to /data/sandcastle/boxes/fbsource/fbcode/ax/storage/json_store/decoder.py\u001b[0;0m\n", + "\u001b[38;5;200mRemapping module ax.storage.json_store.registry from /mnt/xarfuse/uid-28351/02d88658-seed-nspid4026533191_cgpid3387601-ns-4026533188/ax/storage/json_store/registry.py to /data/sandcastle/boxes/fbsource/fbcode/ax/storage/json_store/registry.py\u001b[0;0m\n" + ] + } + ], "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", "\n", @@ -43,9 +86,16 @@ }, { "cell_type": "markdown", - "id": "northern-affairs", "metadata": { - "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + "id": "northern-affairs", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "58ea5ebf-ff3a-40b4-8be3-1b85c99d1c4a" + }, + "originalKey": "c9a665ca-497e-4d7c-bbb5-1b9f8d1d311c", + "outputsInitialized": false, + "showInput": false }, "source": [ "# Setup and Usage of BoTorch Models in Ax\n", @@ -70,9 +120,16 @@ }, { "cell_type": "markdown", - "id": "pending-support", "metadata": { - "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + "id": "pending-support", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "c06d1b5c-067d-4618-977e-c8269a98bd0a" + }, 
+ "originalKey": "4706d02e-6b3f-4161-9e08-f5a31328b1d1", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 1. Quick-start example\n", @@ -82,12 +139,31 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "parental-sending", + "execution_count": 3, "metadata": { - "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + "collapsed": false, + "executionStartTime": 1730390501452, + "executionStopTime": 1730390501699, + "id": "parental-sending", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "72934cf2-4ecf-483a-93bd-4df88b19a7b8" + }, + "originalKey": "4b8633f7-b281-4234-afb2-8b64af8d6a94", + "outputsInitialized": true, + "requestMsgId": "4b8633f7-b281-4234-afb2-8b64af8d6a94", + "serverExecutionDuration": 56.608530983794 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:01:41] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + ] + } + ], "source": [ "experiment = get_branin_experiment(with_trial=True)\n", "data = get_branin_data(trials=[experiment.trials[0]])" @@ -95,12 +171,31 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "rough-somerset", + "execution_count": 4, "metadata": { - "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + "collapsed": false, + "executionStartTime": 1730390502117, + "executionStopTime": 1730390504217, + "id": "rough-somerset", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "e571212c-7872-4ebc-b646-8dad8d4266fd" + }, + "originalKey": "e4519be2-9e06-422c-b5a2-215bc79a9690", + "outputsInitialized": true, + "requestMsgId": "e4519be2-9e06-422c-b5a2-215bc79a9690", + "serverExecutionDuration": 1917.6382739679 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:01:42] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + } + ], "source": [ "# `Models` automatically selects a model + model bridge combination.\n", "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", @@ -114,9 +209,16 @@ }, { "cell_type": "markdown", - "id": "hairy-wiring", "metadata": { - "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + "id": "hairy-wiring", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "fba91372-7aa6-456d-a22b-78ab30c26cd8" + }, + "originalKey": "46f5c2c7-400d-4d8d-b0b9-a241657b173f", + "outputsInitialized": false, + "showInput": false }, "source": [ "Now we can use this model to generate candidates (`gen`), predict outcome at a point (`predict`), or evaluate acquisition function value at a given point (`evaluate_acquisition_function`)." 
@@ -124,12 +226,34 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "consecutive-summary", + "execution_count": 5, "metadata": { - "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + "collapsed": false, + "executionStartTime": 1730390503548, + "executionStopTime": 1730390504462, + "id": "consecutive-summary", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "59582fc6-8089-4320-864e-d98ee271d4f7" + }, + "originalKey": "50fcf9e0-9f78-46ae-90eb-7b1a03be8c9f", + "outputsInitialized": true, + "requestMsgId": "50fcf9e0-9f78-46ae-90eb-7b1a03be8c9f", + "serverExecutionDuration": 243.41962195467 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Arm(parameters={'x1': -5.0, 'x2': 15.0})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "generator_run = model_bridge_with_GPEI.gen(n=1)\n", "generator_run.arms[0]" @@ -137,9 +261,16 @@ }, { "cell_type": "markdown", - "id": "diverse-richards", "metadata": { - "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + "id": "diverse-richards", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "8cfe0fa9-8cce-4718-ba43-e8a63744d626" + }, + "originalKey": "804bac30-db07-4444-98a2-7a5f05007495", + "outputsInitialized": false, + "showInput": false }, "source": [ "-----\n", @@ -151,9 +282,16 @@ }, { "cell_type": "markdown", - "id": "grand-committee", "metadata": { - "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + "id": "grand-committee", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "7037fd14-bcfe-44f9-b915-c23915d2bda9" + }, + "originalKey": "31b54ce5-2590-4617-b10c-d24ed3cce51d", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 2. BoTorchModel = Surrogate + Acquisition\n", @@ -163,9 +301,16 @@ }, { "cell_type": "markdown", - "id": "thousand-blanket", "metadata": { - "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + "id": "thousand-blanket", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "08b12c6c-14da-4342-95bd-f607a131ce9d" + }, + "originalKey": "4a4e006e-07fa-4d63-8b9a-31b67075e40e", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 2A. Example that uses defaults and requires no options\n", @@ -175,10 +320,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "changing-xerox", + "execution_count": 6, "metadata": { - "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + "collapsed": false, + "executionStartTime": 1730390506087, + "executionStopTime": 1730390506305, + "id": "changing-xerox", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "b1bca702-07b2-4818-b2b9-2107268c383c" + }, + "originalKey": "34c9101c-7239-4782-9c1b-ed07cea3a315", + "outputsInitialized": true, + "requestMsgId": "34c9101c-7239-4782-9c1b-ed07cea3a315", + "serverExecutionDuration": 1.7009290168062 }, "outputs": [], "source": [ @@ -196,9 +352,16 @@ }, { "cell_type": "markdown", - "id": "lovely-mechanics", "metadata": { - "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + "id": "lovely-mechanics", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "5cec0f06-ae2c-47d3-bd95-441c45762e38" + }, + "originalKey": "7b9fae38-fe5d-4e5b-8b5f-2953c1ef09d2", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 2B. 
Example with all the options\n", @@ -207,10 +370,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "twenty-greek", + "execution_count": 7, "metadata": { - "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + "collapsed": false, + "executionStartTime": 1730390507109, + "executionStopTime": 1730390507616, + "id": "twenty-greek", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "25b13c48-edb0-4b3f-ba34-4f4a4176162a" + }, + "originalKey": "7f539164-4fa5-492c-a530-8e119ac06399", + "outputsInitialized": true, + "requestMsgId": "7f539164-4fa5-492c-a530-8e119ac06399", + "serverExecutionDuration": 1.9234410137869 }, "outputs": [], "source": [ @@ -242,9 +416,16 @@ }, { "cell_type": "markdown", - "id": "fourth-material", "metadata": { - "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + "id": "fourth-material", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" + }, + "originalKey": "7140bb19-09b4-4abe-951d-53902ae07833", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", @@ -258,9 +439,16 @@ }, { "cell_type": "markdown", - "id": "violent-course", "metadata": { - "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + "id": "violent-course", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "86018ee5-f7b8-41ae-8e2d-460fe5f0c15b" + }, + "originalKey": "71f92895-874d-4fc7-ae87-a5519b18d1a0", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 3. I know which Botorch `Model` and `AcquisitionFunction` I'd like to combine in Ax. How do set this up?" @@ -268,11 +456,18 @@ }, { "cell_type": "markdown", - "id": "unlike-football", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "id": "unlike-football", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "b29a846d-d7bc-4143-8318-10170c9b4298", + "showInput": false + }, + "originalKey": "4af8afa2-5056-46be-b7b9-428127e668cc", + "outputsInitialized": false, "showInput": false }, "source": [ @@ -286,12 +481,23 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "dynamic-university", + "execution_count": 8, "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + "collapsed": false, + "executionStartTime": 1730390509122, + "executionStopTime": 1730390509315, + "id": "dynamic-university", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6c2ea955-c7a4-42ff-a4d7-f787113d4d53" + }, + "originalKey": "1c4d7794-1561-412c-9bda-fdd5c575149f", + "outputsInitialized": true, + "requestMsgId": "1c4d7794-1561-412c-9bda-fdd5c575149f", + "serverExecutionDuration": 2.4470969801769 }, "outputs": [], "source": [ @@ -326,9 +532,16 @@ }, { "cell_type": "markdown", - "id": "otherwise-context", "metadata": { - "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + "id": "otherwise-context", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "b9072296-956d-4add-b1f6-e7e0415ba65c" + }, + "originalKey": "5a27fd2c-4c4c-41fe-a634-f6d0ec4f1666", + "outputsInitialized": false, + "showInput": false }, "source": [ "NOTE: if you run into a case where base `Surrogate` does not work with your BoTorch `Model`, 
please let us know in this Github issue: https://github.com/facebook/Ax/issues/363, so we can find the right solution and augment this tutorial." @@ -336,9 +549,16 @@ }, { "cell_type": "markdown", - "id": "northern-invite", "metadata": { - "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + "id": "northern-invite", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "335cabdf-2bf6-48e8-ba0c-1404a8ef47f9" + }, + "originalKey": "df06d02b-95cb-4d34-aac6-773231f1a129", + "outputsInitialized": false, + "showInput": false }, "source": [ "### 3B. Using an arbitrary BoTorch `AcquisitionFunction` in Ax" @@ -346,11 +566,18 @@ }, { "cell_type": "markdown", - "id": "surrounded-denial", "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "id": "surrounded-denial", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "e3f0c788-2131-4116-9518-4ae7daeb991f", + "showInput": false + }, + "originalKey": "d4861847-b757-4fcd-9f35-ba258080812c", + "outputsInitialized": false, "showInput": false }, "source": [ @@ -363,14 +590,36 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "interested-search", + "execution_count": 9, "metadata": { - "code_folding": [], - "hidden_ranges": [], - "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + "collapsed": false, + "executionStartTime": 1730390510990, + "executionStopTime": 1730390511184, + "id": "interested-search", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "code_folding": [], + "hidden_ranges": [], + "originalKey": "6967ce3e-929b-4d9a-8cd1-72bf94f0be3a" + }, + "originalKey": "99d72703-5c7b-419d-b400-e2bfa94cf093", + "outputsInitialized": true, + "requestMsgId": "99d72703-5c7b-419d-b400-e2bfa94cf093", + "serverExecutionDuration": 5.7275110157207 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "BoTorchModel" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from botorch.acquisition.acquisition import AcquisitionFunction\n", "from botorch.acquisition.input_constructors import acqf_input_constructor, MaybeDict\n", @@ -409,9 +658,16 @@ }, { "cell_type": "markdown", - "id": "metallic-imaging", "metadata": { - "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + "id": "metallic-imaging", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "29256ab1-f214-4604-a423-4c7b4b36baa0" + }, + "originalKey": "b057722d-b8ca-47dd-b2c8-1ff4a71c4863", + "outputsInitialized": false, + "showInput": false }, "source": [ "See section 2A for combining the resulting `Surrogate` instance and `Acquisition` type into a `BoTorchModel`. You can also leverage `Models.BOTORCH_MODULAR` for ease of use; more on it in section 4 below or in section 1 quick-start example." @@ -419,9 +675,16 @@ }, { "cell_type": "markdown", - "id": "descending-australian", "metadata": { - "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + "id": "descending-australian", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "1d15082f-1df7-4cdb-958b-300483eb7808" + }, + "originalKey": "a7406f13-1468-487d-ac5e-7d2a45394850", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 4. 
Using `Models.BOTORCH_MODULAR` \n", @@ -433,12 +696,41 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "attached-border", + "execution_count": 10, "metadata": { - "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + "collapsed": false, + "executionStartTime": 1730390512519, + "executionStopTime": 1730390512922, + "id": "attached-border", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "385b2f30-fd86-4d88-8784-f238ea8a6abb" + }, + "originalKey": "144dffff-ee7c-4195-ad09-6df9b506a26e", + "outputsInitialized": true, + "requestMsgId": "144dffff-ee7c-4195-ad09-6df9b506a26e", + "serverExecutionDuration": 226.21790104313 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:01:52] ax.modelbridge.transforms.standardize_y: Outcome branin is constant, within tolerance.\n" + ] + }, + { + "data": { + "text/plain": [ + "GeneratorRun(1 arms, total weight 1.0)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", " experiment=experiment,\n", @@ -449,33 +741,84 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "powerful-gamma", + "execution_count": 11, "metadata": { - "originalKey": "89930a31-e058-434b-b587-181931e247b6" + "collapsed": false, + "executionStartTime": 1730390513113, + "executionStopTime": 1730390513298, + "id": "powerful-gamma", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "89930a31-e058-434b-b587-181931e247b6" + }, + "originalKey": "9be9b38b-0b88-429b-8fed-ad7956b2819b", + "outputsInitialized": true, + "requestMsgId": "9be9b38b-0b88-429b-8fed-ad7956b2819b", + "serverExecutionDuration": 2.6019790093414 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "botorch.acquisition.logei.qLogNoisyExpectedImprovement" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model_bridge_with_GPEI.model.botorch_acqf_class" ] }, { "cell_type": "code", - "execution_count": null, - "id": "improved-replication", + "execution_count": 17, "metadata": { - "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + "collapsed": false, + "executionStartTime": 1730390632933, + "executionStopTime": 1730390633141, + "id": "improved-replication", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "f9a9cb14-20c3-4e1d-93a3-6a35c281ae01" + }, + "originalKey": "5f1cf672-9ed2-4d09-b7c0-561c0e696cdd", + "outputsInitialized": true, + "requestMsgId": "5f1cf672-9ed2-4d09-b7c0-561c0e696cdd", + "serverExecutionDuration": 3.5180210252292 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "botorch.models.gp_regression.SingleTaskGP" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model_bridge_with_GPEI.model.surrogate.botorch_model_class" + "model_bridge_with_GPEI.model.surrogate.model.__class__" ] }, { "cell_type": "markdown", - "id": "connected-sheet", "metadata": { - "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + "id": "connected-sheet", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "8b6a9ddc-d2d2-4cd5-a6a8-820113f78262" + }, + "originalKey": "f5c0adbd-00a6-428d-810f-1e7ed0954b08", + "outputsInitialized": false, + "showInput": false }, "source": [ "We can use the same `Models.BOTORCH_MODULAR` to set up a model for 
multi-objective optimization:" @@ -483,12 +826,55 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "documentary-jurisdiction", + "execution_count": 18, "metadata": { - "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + "collapsed": false, + "executionStartTime": 1730390633980, + "executionStopTime": 1730390634698, + "id": "documentary-jurisdiction", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "8001de33-d9d9-4888-a5d1-7a59ebeccfd5" + }, + "originalKey": "cc9d8f7b-fcb0-4c75-9217-b898e3b8fb1b", + "outputsInitialized": true, + "requestMsgId": "cc9d8f7b-fcb0-4c75-9217-b898e3b8fb1b", + "serverExecutionDuration": 532.64761995524 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:03:54] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:03:54] ax.modelbridge.transforms.standardize_y: Outcome branin_a is constant, within tolerance.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:03:54] ax.modelbridge.transforms.standardize_y: Outcome branin_b is constant, within tolerance.\n" + ] + }, + { + "data": { + "text/plain": [ + "GeneratorRun(1 arms, total weight 1.0)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model_bridge_with_EHVI = Models.BOTORCH_MODULAR(\n", " experiment=get_branin_experiment_with_multi_objective(\n", @@ -501,33 +887,84 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "changed-maintenance", + "execution_count": 19, "metadata": { - "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + "collapsed": false, + "executionStartTime": 1730390634821, + "executionStopTime": 1730390635027, + "id": "changed-maintenance", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "dcfdbecc-4a9a-49ac-ad55-0bc04b2ec566" + }, + "originalKey": "078c8a35-1f5f-42a8-8e20-48caa861c878", + "outputsInitialized": true, + "requestMsgId": "078c8a35-1f5f-42a8-8e20-48caa861c878", + "serverExecutionDuration": 2.9336949810386 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "botorch.acquisition.multi_objective.logei.qLogNoisyExpectedHypervolumeImprovement" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model_bridge_with_EHVI.model.botorch_acqf_class" ] }, { "cell_type": "code", - "execution_count": null, - "id": "operating-shelf", + "execution_count": 21, "metadata": { - "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + "collapsed": false, + "executionStartTime": 1730390643428, + "executionStopTime": 1730390643623, + "id": "operating-shelf", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "16727a51-337d-4715-bf51-9cb6637a950f" + }, + "originalKey": "f17468dc-0387-46b5-bd37-932fa19e5660", + "outputsInitialized": true, + "requestMsgId": "f17468dc-0387-46b5-bd37-932fa19e5660", + "serverExecutionDuration": 3.0451429774985 }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "botorch.models.gp_regression.SingleTaskGP" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - 
"model_bridge_with_EHVI.model.surrogate.botorch_model_class" + "model_bridge_with_EHVI.model.surrogate.model.__class__" ] }, { "cell_type": "markdown", - "id": "fatal-butterfly", "metadata": { - "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + "id": "fatal-butterfly", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "5c64eecc-5ce5-4907-bbcc-5b3cbf4358ae" + }, + "originalKey": "3ad7c4a7-fe19-44ad-938d-1be4f8b09bfb", + "outputsInitialized": false, + "showInput": false }, "source": [ "Furthermore, the quick-start example at the top of this tutorial shows how to specify surrogate and acquisition subcomponents to `Models.BOTORCH_MODULAR`. " @@ -535,9 +972,16 @@ }, { "cell_type": "markdown", - "id": "hearing-interface", "metadata": { - "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + "id": "hearing-interface", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "a0163432-f0ca-4582-ad84-16c77c99f20b" + }, + "originalKey": "44adf1ce-6d3e-455d-b53c-32d3c42a843f", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 5. Utilizing `BoTorchModel` in generation strategies\n", @@ -549,10 +993,21 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "received-registration", + "execution_count": 22, "metadata": { - "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + "collapsed": false, + "executionStartTime": 1730390656119, + "executionStopTime": 1730390656318, + "id": "received-registration", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "f7eabbcf-607c-4bed-9a0e-6ac6e8b04350" + }, + "originalKey": "fd76a931-1e1c-4a6c-85ae-07b95b7930a0", + "outputsInitialized": true, + "requestMsgId": "fd76a931-1e1c-4a6c-85ae-07b95b7930a0", + "serverExecutionDuration": 2.3889989824966 }, "outputs": [], "source": [ @@ -586,9 +1041,16 @@ }, { "cell_type": "markdown", - "id": "logical-windsor", "metadata": { - "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + "id": "logical-windsor", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "212c4543-220e-4605-8f72-5f86cf52f722" + }, + "originalKey": "ba3783ee-3d88-4e44-ad07-77de3c50f84d", + "outputsInitialized": false, + "showInput": false }, "source": [ "Set up an experiment and generate 10 trials in it, adding synthetic data to experiment after each one:" @@ -596,12 +1058,41 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "viral-cheese", + "execution_count": 23, "metadata": { - "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + "collapsed": false, + "executionStartTime": 1730390657103, + "executionStopTime": 1730390657356, + "id": "viral-cheese", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "30cfcdd7-721d-4f89-b851-7a94140dfad6" + }, + "originalKey": "af97b6c0-a58f-48b5-8275-8587335e6202", + "outputsInitialized": true, + "requestMsgId": "af97b6c0-a58f-48b5-8275-8587335e6202", + "serverExecutionDuration": 3.7560370401479 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 10-31 09:04:17] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. 
If you are running a live experiment, please set this flag to False\n" + ] + }, + { + "data": { + "text/plain": [ + "SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "experiment = get_branin_experiment(minimize=True)\n", "\n", @@ -611,9 +1102,16 @@ }, { "cell_type": "markdown", - "id": "incident-newspaper", "metadata": { - "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + "id": "incident-newspaper", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "2807d7ce-8a6b-423c-b5f5-32edba09c78e" + }, + "originalKey": "df2e90f5-4132-4d87-989b-e6d47c748ddc", + "outputsInitialized": false, + "showInput": false }, "source": [ "## 5a. Specifying `pending_observations`\n", @@ -624,12 +1122,70 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "casual-spread", + "execution_count": 24, "metadata": { - "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + "collapsed": false, + "executionStartTime": 1730390658427, + "executionStopTime": 1730390661681, + "id": "casual-spread", + "isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "58aafd65-a366-4b66-a1b1-31b207037a2e" + }, + "originalKey": "f3bf846c-ab72-46b2-9266-27df704caddf", + "outputsInitialized": true, + "requestMsgId": "f3bf846c-ab72-46b2-9266-27df704caddf", + "serverExecutionDuration": 3052.0362380194 }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #0, suggested by Sobol.\n", + "Completed trial #1, suggested by Sobol.\n", + "Completed trial #2, suggested by Sobol.\n", + "Completed trial #3, suggested by Sobol.\n", + "Completed trial #4, suggested by Sobol.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #5, suggested by BoTorch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #6, suggested by BoTorch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #7, suggested by BoTorch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #8, suggested by BoTorch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Completed trial #9, suggested by BoTorch.\n" + ] + } + ], "source": [ "for _ in range(10):\n", " # Produce a new generator run and attach it to experiment as a trial\n", @@ -652,9 +1208,16 @@ }, { "cell_type": "markdown", - "id": "circular-vermont", "metadata": { - "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + "id": "circular-vermont", + "isAgentGenerated": false, + "language": "markdown", + "metadata": { + "originalKey": "9d3b86bf-b691-4315-8b8f-60504b37818c" + }, + "originalKey": "6a78ef13-fbaa-4cae-934b-d57f5807fe25", + "outputsInitialized": false, + "showInput": false }, "source": [ "Now we examine the experiment and observe the trials that were added to it and produced by the generation strategy:" @@ -662,21 +1225,202 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "significant-particular", + "execution_count": 25, "metadata": { - "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + "collapsed": false, + "executionStartTime": 1730390659779, + "executionStopTime": 1730390661743, + "id": "significant-particular", + 
"isAgentGenerated": false, + "language": "python", + "metadata": { + "originalKey": "ca12913d-e3fd-4617-a247-e3432665bac1" + }, + "originalKey": "112e7db1-a97f-4d03-8d5c-f9090a85b486", + "outputsInitialized": true, + "requestMsgId": "112e7db1-a97f-4d03-8d5c-f9090a85b486", + "serverExecutionDuration": 46.194189984817 }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING 10-31 09:04:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + " | trial_index | \n", + "arm_name | \n", + "trial_status | \n", + "generation_method | \n", + "branin | \n", + "x1 | \n", + "x2 | \n", + "
---|---|---|---|---|---|---|---|
0 | \n", + "0 | \n", + "0_0 | \n", + "COMPLETED | \n", + "Sobol | \n", + "14.843151 | \n", + "-2.583597 | \n", + "14.578224 | \n", + "
1 | \n", + "1 | \n", + "1_0 | \n", + "COMPLETED | \n", + "Sobol | \n", + "33.498612 | \n", + "6.451221 | \n", + "4.854893 | \n", + "
2 | \n", + "2 | \n", + "2_0 | \n", + "COMPLETED | \n", + "Sobol | \n", + "60.273136 | \n", + "3.505221 | \n", + "9.705736 | \n", + "
3 | \n", + "3 | \n", + "3_0 | \n", + "COMPLETED | \n", + "Sobol | \n", + "8.794582 | \n", + "2.233962 | \n", + "0.919963 | \n", + "
4 | \n", + "4 | \n", + "4_0 | \n", + "COMPLETED | \n", + "Sobol | \n", + "22.418533 | \n", + "-0.583571 | \n", + "9.071722 | \n", + "
5 | \n", + "5 | \n", + "5_0 | \n", + "COMPLETED | \n", + "BoTorch | \n", + "50.414462 | \n", + "0.196160 | \n", + "0.125324 | \n", + "
6 | \n", + "6 | \n", + "6_0 | \n", + "COMPLETED | \n", + "BoTorch | \n", + "6.743681 | \n", + "2.906018 | \n", + "0.000000 | \n", + "
7 | \n", + "7 | \n", + "7_0 | \n", + "COMPLETED | \n", + "BoTorch | \n", + "1.930223 | \n", + "-2.664758 | \n", + "10.479278 | \n", + "
8 | \n", + "8 | \n", + "8_0 | \n", + "COMPLETED | \n", + "BoTorch | \n", + "14.314835 | \n", + "-3.770309 | \n", + "10.361267 | \n", + "
9 | \n", + "9 | \n", + "9_0 | \n", + "COMPLETED | \n", + "BoTorch | \n", + "11.278202 | \n", + "-2.514926 | \n", + "7.810419 | \n", + "
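The table above is the cell's rendered DataFrame output summarizing all ten trials. As a hedged sketch only (assuming, consistent with the `ax.service.utils.report_utils` warning in the cell's stderr output, that the hidden source line calls `exp_to_df`), this is roughly how such a summary can be produced and how the best Branin arm could be pulled out of it; the `experiment` name comes from the earlier `get_branin_experiment(minimize=True)` cell:

```python
# Sketch, not the notebook's verbatim source: exp_to_df flattens an Ax experiment
# into a pandas DataFrame with one row per arm (trial_index, arm_name, trial_status,
# generation_method, metric values, and parameter values).
from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)

# Branin is being minimized, so the smallest `branin` value marks the best arm so far.
best_row = df.loc[df["branin"].idxmin()]
print(best_row[["arm_name", "generation_method", "branin", "x1", "x2"]])
```

In the run shown above that would be arm `7_0` (branin ≈ 1.93), which was suggested by the BoTorch step of the generation strategy rather than the initial Sobol step.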