From 7039fc7772981b7dbe0102297ed9c7cc1e34fc6a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 6 Oct 2023 10:23:16 -0700 Subject: [PATCH] MBM Surrogate: Clean up model fitting behavior and **kwargs usage (#1882) Summary: Kwargs: These were originally added to support passing around additional kwargs in subclasses that no longer exist. They later silently took on the role of carrying the kwargs that get passed down to the model input constructors. The argument name has been updated, with an added docstring explaining what these kwargs do. These are now used to update `Surrogate.model_options` and are passed to the input constructors from there (a usage sketch follows the diff below). Model fitting cleanup: We had duplicate logic for constructing the models, left over from the merger of `Surrogate` & `ListSurrogate`, which led to several bugs in the past and made the code harder to maintain and review. This diff deduplicates and simplifies the model fitting logic. Differential Revision: D49707895 --- ax/modelbridge/tests/test_registry.py | 8 +- ax/models/torch/botorch_modular/model.py | 32 +- ax/models/torch/botorch_modular/surrogate.py | 362 ++++++------------- ax/models/torch/botorch_modular/utils.py | 10 +- ax/models/torch/tests/test_acquisition.py | 2 +- ax/models/torch/tests/test_model.py | 1 + ax/models/torch/tests/test_surrogate.py | 69 ++-- ax/models/torch/tests/test_utils.py | 31 +- 8 files changed, 173 insertions(+), 342 deletions(-) diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index ca587c13cd1..5ad29d9fa4f 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -47,7 +47,7 @@ from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP from botorch.models.gp_regression import FixedNoiseGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.models.multitask import FixedNoiseMultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.utils.types import DEFAULT from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.scale_kernel import ScaleKernel @@ -475,7 +475,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: if use_saas else [ Surrogate( - botorch_model_class=FixedNoiseMultiTaskGP, + botorch_model_class=MultiTaskGP, mll_class=ExactMarginalLogLikelihood, covar_module_class=ScaleMaternKernel, covar_module_options={ @@ -514,9 +514,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: for i in range(len(models)): self.assertIsInstance( models[i], - SaasFullyBayesianMultiTaskGP - if use_saas - else FixedNoiseMultiTaskGP, + SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, ) if use_saas is False: self.assertIsInstance(models[i].covar_module, ScaleKernel) diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 926c5c9ad9d..ee80a474000 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -24,7 +24,6 @@ from ax.models.torch.botorch_modular.utils import ( choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, - convert_to_block_design, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -33,7 +32,6 @@ from ax.utils.common.docutils import copy_doc from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.models import ModelList from
botorch.models.deterministic import FixedSingleSampleModel from botorch.models.model import Model from botorch.models.transforms.input import InputTransform @@ -247,7 +245,7 @@ def fit( # state dict by surrogate label state_dicts: Optional[Mapping[str, Dict[str, Tensor]]] = None, refit: bool = True, - **kwargs: Any, + **additional_model_inputs: Any, ) -> None: """Fit model to m outcomes. @@ -264,6 +262,8 @@ def fit( surrogate_specs. If using a single, pre-instantiated model use `Keys.ONLY_SURROGATE. refit: Whether to re-optimize model parameters. + additional_model_inputs: Additional kwargs to pass to the + model input constructor in ``Surrogate.fit``. """ if len(datasets) != len(metric_names): @@ -288,7 +288,7 @@ def fit( if state_dicts else None, refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) return @@ -340,20 +340,6 @@ def fit( datasets_by_metric_name[metric_name] for metric_name in subset_metric_names ] - if ( - len(subset_datasets) > 1 - # if Surrogate's model is none a ModelList will be autoset - and surrogate._model is not None - and not isinstance(surrogate.model, ModelList) - ): - # Note: If the datasets do not confirm to a block design then this - # will filter the data and drop observations to make sure that it does. - # This can happen e.g. if only some metrics are observed at some points - subset_datasets, metric_names = convert_to_block_design( - datasets=subset_datasets, - metric_names=metric_names, - force=True, - ) surrogate.fit( datasets=subset_datasets, @@ -362,7 +348,7 @@ def fit( candidate_metadata=candidate_metadata, state_dict=(state_dicts or {}).get(label), refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) @copy_doc(TorchModel.update) @@ -372,7 +358,7 @@ def update( metric_names: List[str], search_space_digest: SearchSpaceDigest, candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, - **kwargs: Any, + **additional_model_inputs: Any, ) -> None: if len(self.surrogates) == 0: raise UnsupportedError("Cannot update model that has not been fitted.") @@ -417,7 +403,7 @@ def update( candidate_metadata=candidate_metadata, state_dict=state_dict, refit=self.refit_on_update, - **kwargs, + additional_model_inputs=additional_model_inputs, ) @single_surrogate_only @@ -536,7 +522,7 @@ def cross_validate( metric_names: List[str], X_test: Tensor, search_space_digest: SearchSpaceDigest, - **kwargs: Any, + **additional_model_inputs: Any, ) -> Tuple[Tensor, Tensor]: # Will fail if metric_names exist across multiple models surrogate_labels = ( @@ -589,7 +575,7 @@ def cross_validate( search_space_digest=search_space_digest, state_dicts=state_dicts, refit=self.refit_on_cv, - **kwargs, + **additional_model_inputs, ) X_test_prediction = self.predict_from_surrogate( surrogate_label=surrogate_label, X=X_test diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index e1ad27c322f..b7122373dbb 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -6,16 +6,14 @@ from __future__ import annotations -import dataclasses - import inspect import warnings from copy import deepcopy from logging import Logger -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch -from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest +from ax.core.search_space import SearchSpaceDigest from ax.core.types import 
TCandidateMetadata from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError from ax.models.model_utils import best_in_sample_point @@ -28,7 +26,6 @@ from ax.models.torch.botorch_modular.input_constructors.outcome_transform import ( outcome_transform_argparse, ) - from ax.models.torch.botorch_modular.utils import ( choose_model_class, convert_to_block_design, @@ -55,7 +52,6 @@ InputTransform, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform - from botorch.utils.datasets import RankingDataset, SupervisedDataset from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood @@ -126,20 +122,22 @@ class string names and the values are dictionaries of input transform Set to false to fit individual models to each metric in a loop. """ + # These attributes are instantiated in __init__. botorch_model_class: Optional[Type[Model]] model_options: Dict[str, Any] mll_class: Type[MarginalLogLikelihood] mll_options: Dict[str, Any] - outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None - outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None - input_transform_classes: Optional[List[Type[InputTransform]]] = None - input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None - covar_module_class: Optional[Type[Kernel]] = None + outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] + outcome_transform_options: Dict[str, Dict[str, Any]] + input_transform_classes: Optional[List[Type[InputTransform]]] + input_transform_options: Dict[str, Dict[str, Any]] + covar_module_class: Optional[Type[Kernel]] covar_module_options: Dict[str, Any] - likelihood_class: Optional[Type[Likelihood]] = None + likelihood_class: Optional[Type[Likelihood]] likelihood_options: Dict[str, Any] - allow_batched_models: bool = True + allow_batched_models: bool + # These are later updated during model fitting. _training_data: Optional[List[SupervisedDataset]] = None _outcomes: Optional[List[str]] = None _model: Optional[Model] = None @@ -250,7 +248,6 @@ def construct( datasets: List[SupervisedDataset], metric_names: List[str], search_space_digest: SearchSpaceDigest, - **kwargs: Any, ) -> None: """Constructs the underlying BoTorch ``Model`` using the training data. @@ -261,11 +258,7 @@ def construct( corresponding to the i-th dataset. search_space_digest: Information about the search space used for inferring suitable botorch model class. - **kwargs: Optional keyword arguments, expects any of: - - "fidelity_features": Indices of columns in X that represent - fidelity. 
""" - if self._constructed_manually: logger.warning("Reconstructing a manually constructed `Model`.") @@ -277,41 +270,40 @@ def construct( search_space_digest=not_none(search_space_digest), ) - if use_model_list( + should_use_model_list = use_model_list( datasets=datasets, botorch_model_class=botorch_model_class, allow_batched_models=self.allow_batched_models, - ): - self._construct_model_list( + ) + + if not should_use_model_list and len(datasets) > 1: + datasets, metric_names = convert_to_block_design( datasets=datasets, metric_names=metric_names, - search_space_digest=search_space_digest, - **kwargs, + force=True, ) - else: - if self.botorch_model_class is None: - self.botorch_model_class = botorch_model_class - - if len(datasets) > 1: - datasets, metric_names = convert_to_block_design( - datasets=datasets, - metric_names=metric_names, - force=True, - ) - kwargs["metric_names"] = metric_names + self._training_data = datasets - self._construct_model( - dataset=datasets[0], + models = [] + for dataset in datasets: + model = self._construct_model( + dataset=dataset, search_space_digest=search_space_digest, - **kwargs, + botorch_model_class=botorch_model_class, ) + models.append(model) + + if should_use_model_list: + self._model = ModelListGP(*models) + else: + self._model = models[0] def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - **kwargs: Any, - ) -> None: + botorch_model_class: Type[Model], + ) -> Model: """Constructs the underlying BoTorch ``Model`` using the training data. Args: @@ -320,26 +312,30 @@ def _construct_model( multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. - **kwargs: Optional keyword arguments, expects any of: - - "fidelity_features": Indices of columns in X that represent - fidelity. + botorch_model_class: ``Model`` class to be used as the underlying + BoTorch model. """ - if self.botorch_model_class is None: - raise ValueError( - "botorch_model_class must be set to construct single model Surrogate." - ) - botorch_model_class = self.botorch_model_class + ( + fidelity_features, + task_feature, + categorical_features, + input_transform_classes, + input_transform_options, + ) = self._extract_construct_model_list_kwargs( + search_space_digest=search_space_digest, + ) - input_constructor_kwargs = {**self.model_options, **(kwargs or {})} + input_constructor_kwargs = { + **self.model_options, + "fidelity_features": fidelity_features, + "task_feature": task_feature, + "categorical_features": categorical_features, + } botorch_model_class_args = inspect.getfullargspec(botorch_model_class).args # Temporary workaround to allow models to consume data from - # `FixedNoiseDataset`s even if they don't accept variance observations - if ( - "train_Yvar" not in botorch_model_class_args - and dataset.Yvar is not None - and not isinstance(dataset, RankingDataset) - ): + # `FixedNoiseDataset`s even if they don't accept variance observations. + if "train_Yvar" not in botorch_model_class_args and dataset.Yvar is not None: warnings.warn( f"Provided model class {botorch_model_class} does not accept " "`train_Yvar` argument, but received dataset with `Yvar`. 
Ignoring " @@ -354,157 +350,50 @@ def _construct_model( outcome_names=dataset.outcome_names, ) - self._training_data = [dataset] - formatted_model_inputs = botorch_model_class.construct_inputs( training_data=dataset, **input_constructor_kwargs ) self._set_formatted_inputs( formatted_model_inputs=formatted_model_inputs, inputs=[ - [ + ( "covar_module", self.covar_module_class, self.covar_module_options, - None, - ], - ["likelihood", self.likelihood_class, self.likelihood_options, None], - [ + ), + ("likelihood", self.likelihood_class, self.likelihood_options), + ( "outcome_transform", self.outcome_transform_classes, self.outcome_transform_options, - None, - ], - [ + ), + ( "input_transform", - self.input_transform_classes, - self.input_transform_options, - None, - ], + input_transform_classes, + deepcopy(input_transform_options), + ), ], dataset=dataset, search_space_digest=search_space_digest, botorch_model_class_args=botorch_model_class_args, ) # pyre-ignore [45] - self._model = botorch_model_class(**formatted_model_inputs) - - def _construct_model_list( - self, - datasets: List[SupervisedDataset], - metric_names: Iterable[str], - search_space_digest: SearchSpaceDigest, - **kwargs: Any, - ) -> None: - """Constructs the underlying BoTorch ``Model`` using the training data. - - Args: - datasets: List of ``SupervisedDataset`` for the submodels of - ``ModelListGP``. Each training data is for one outcome, and the order - of outcomes should match the order of metrics in ``metric_names`` - argument. - metric_names: Names of metrics, in the same order as datasets (so if - datasets is ``[ds_A, ds_B]``, the metrics are ``["A" and "B"]``). - These are used to match training data with correct submodels of - ``ModelListGP``. - search_space_digest: SearchSpaceDigest must be provided if no - botorch_submodel_class is provided so the appropriate botorch model - class can be automatically selected. - - **kwargs: Keyword arguments, accepts: - - ``fidelity_features``: Indices of columns in X that represent - fidelity features. - - ``task_features``: Indices of columns in X that represent tasks. - - ``categorical_features``: Indices of columns in X that represent - categorical features. - - ``robust_digest``: An optional `RobustSearchSpaceDigest` that carries - additional attributes if using a `RobustSearchSpace`. - """ - - self._training_data = datasets - - ( - fidelity_features, - task_feature, - submodel_input_transform_classes, - submodel_input_transform_options, - ) = self._extract_construct_model_list_kwargs( - fidelity_features=kwargs.pop(Keys.FIDELITY_FEATURES, []), - task_features=kwargs.pop(Keys.TASK_FEATURES, []), - robust_digest=kwargs.pop("robust_digest", None), - ) - - input_constructor_kwargs = {**self.model_options, **(kwargs or {})} - - submodels = [] - for m, dataset in zip(metric_names, datasets): - model_cls = self.botorch_model_class or choose_model_class( - datasets=[dataset], search_space_digest=not_none(search_space_digest) - ) - - if self._outcomes is not None and m not in self._outcomes: - logger.warning(f"Metric {m} not in training data.") - continue - formatted_model_inputs = model_cls.construct_inputs( - training_data=dataset, - fidelity_features=fidelity_features, - task_feature=task_feature, - **input_constructor_kwargs, - ) - # Add input / outcome transforms. - # TODO: The use of `inspect` here is not ideal. We should find a better - # way to filter the arguments. See the comment in `Surrogate.construct` - # regarding potential use of a `ModelFactory` in the future. 
- model_cls_args = inspect.getfullargspec(model_cls).args - self._set_formatted_inputs( - formatted_model_inputs=formatted_model_inputs, - inputs=[ - [ - "covar_module", - self.covar_module_class, - deepcopy(self.covar_module_options), - None, - ], - [ - "likelihood", - self.likelihood_class, - deepcopy(self.likelihood_options), - None, - ], - [ - "outcome_transform", - self.outcome_transform_classes, - deepcopy(self.outcome_transform_options), - None, - ], - [ - "input_transform", - submodel_input_transform_classes, - deepcopy(submodel_input_transform_options), - None, - ], - ], - dataset=dataset, - search_space_digest=search_space_digest, - botorch_model_class_args=model_cls_args, - ) - # pyre-ignore[45]: Py raises informative error if model is abstract. - submodels.append(model_cls(**formatted_model_inputs)) - - self._model = ModelListGP(*submodels) + return botorch_model_class(**formatted_model_inputs) def _set_formatted_inputs( self, formatted_model_inputs: Dict[str, Any], - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - inputs: List[List[Any]], + # pyre-ignore [2] The proper hint for the second arg is Union[None, + # Type[Kernel], Type[Likelihood], List[Type[OutcomeTransform]], + # List[Type[InputTransform]]]. Keeping it as Any saves us from a + # bunch of checked_cast calls within the for loop. + inputs: List[Tuple[str, Any, Dict[str, Any]]], dataset: SupervisedDataset, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - botorch_model_class_args: Any, + botorch_model_class_args: List[str], search_space_digest: SearchSpaceDigest, ) -> None: - for input_name, input_class, input_options, input_object in inputs: - if input_class is None and input_object is None: + for input_name, input_class, input_options in inputs: + if input_class is None: continue if input_name not in botorch_model_class_args: # TODO: We currently only pass in `covar_module` and `likelihood` @@ -515,68 +404,38 @@ def _set_formatted_inputs( f"The BoTorch model class {self.botorch_model_class} does not " f"support the input {input_name}." ) - if input_class is not None and input_object is not None: - raise RuntimeError(f"Got both a class and an object for {input_name}.") - if input_class is not None: - input_options = input_options or {} - - if input_name == "covar_module": - covar_module_with_defaults = covar_module_argparse( - input_class, - dataset=dataset, - botorch_model_class=self.botorch_model_class, - **input_options, - ) - - formatted_model_inputs[input_name] = input_class( - **covar_module_with_defaults - ) - - elif input_name == "input_transform": - - formatted_model_inputs[ - input_name - ] = self._make_botorch_input_transform( - input_classes=input_class, - input_options=input_options, - dataset=dataset, - search_space_digest=search_space_digest, - ) - - elif input_name == "outcome_transform": - - formatted_model_inputs[ - input_name - ] = self._make_botorch_outcome_transform( - input_classes=input_class, - input_options=input_options, - dataset=dataset, - ) - else: - formatted_model_inputs[input_name] = input_class(**input_options) - - else: - formatted_model_inputs[input_name] = input_object + input_options = deepcopy(input_options) or {} + + if input_name == "covar_module": + covar_module_with_defaults = covar_module_argparse( + input_class, + dataset=dataset, + botorch_model_class=self.botorch_model_class, + **input_options, + ) - # Construct input perturbation if doing robust optimization. 
- robust_digest = search_space_digest.robust_digest - if robust_digest is not None: + formatted_model_inputs[input_name] = input_class( + **covar_module_with_defaults + ) - perturbation = self._make_botorch_input_transform( - input_classes=[InputPerturbation], - dataset=dataset, - search_space_digest=search_space_digest, - input_options={}, - ) + elif input_name == "input_transform": + formatted_model_inputs[input_name] = self._make_botorch_input_transform( + input_classes=input_class, + input_options=input_options, + dataset=dataset, + search_space_digest=search_space_digest, + ) - if formatted_model_inputs.get("input_transform") is not None: - # TODO: Support mixing with user supplied transforms. - raise NotImplementedError( - "User supplied input transforms are not supported " - "in robust optimization." + elif input_name == "outcome_transform": + formatted_model_inputs[ + input_name + ] = self._make_botorch_outcome_transform( + input_classes=input_class, + input_options=input_options, + dataset=dataset, ) else: - formatted_model_inputs["input_transform"] = perturbation + formatted_model_inputs[input_name] = input_class(**input_options) def _make_botorch_input_transform( self, @@ -675,7 +534,7 @@ def fit( candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, state_dict: Optional[Dict[str, Tensor]] = None, refit: bool = True, - **kwargs: Any, + additional_model_inputs: Optional[Dict[str, Any]] = None, ) -> None: """Fits the underlying BoTorch ``Model`` to ``m`` outcomes. @@ -706,7 +565,10 @@ def fit( the order corresponding to the Xs. state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters. + additional_model_inputs: Additional kwargs to pass to the + model input constructor. """ + self.model_options.update(additional_model_inputs or {}) if self._constructed_manually: logger.debug( "For manually constructed surrogates (via `Surrogate.from_botorch`), " @@ -714,13 +576,10 @@ def fit( "its parameters if `refit=True`." ) else: - _kwargs = dataclasses.asdict(search_space_digest) - _kwargs.update(kwargs) self.construct( datasets=datasets, metric_names=metric_names, search_space_digest=search_space_digest, - **_kwargs, ) self._outcomes = metric_names @@ -827,7 +686,7 @@ def update( candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, state_dict: Optional[Dict[str, Tensor]] = None, refit: bool = True, - **kwargs: Any, + additional_model_inputs: Optional[Dict[str, Any]] = None, ) -> None: """Updates the surrogate model with new data. In the base ``Surrogate``, just calls ``fit`` after checking that this surrogate was not created @@ -848,6 +707,8 @@ def update( state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters or just set the training data used for interence to new training data. + additional_model_inputs: Additional kwargs to pass to the + model input constructor. """ # NOTE: In the future, could have `incremental` kwarg, in which case # `training_data` could contain just the new data. 
@@ -865,7 +726,7 @@ def update( candidate_metadata=candidate_metadata, state_dict=state_dict, refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) def pareto_frontier(self) -> Tuple[Tensor, Tensor]: @@ -910,17 +771,16 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]: } def _extract_construct_model_list_kwargs( - self, - fidelity_features: Sequence[int], - task_features: Sequence[int], - robust_digest: Optional[RobustSearchSpaceDigest] = None, + self, search_space_digest: SearchSpaceDigest ) -> Tuple[ List[int], Optional[int], + List[int], Optional[List[Type[InputTransform]]], - Optional[Dict[str, Dict["str", Any]]], + Dict[str, Dict[str, Any]], ]: - + fidelity_features = search_space_digest.fidelity_features + task_features = search_space_digest.task_features if len(fidelity_features) > 0 and len(task_features) > 0: raise NotImplementedError( "Multi-Fidelity GP models with task_features are " @@ -938,8 +798,7 @@ def _extract_construct_model_list_kwargs( # Construct input perturbation if doing robust optimization. # NOTE: Doing this here rather than in `_set_formatted_inputs` to make sure # we use the same perturbations for each sub-model. - if robust_digest is not None: - + if (robust_digest := search_space_digest.robust_digest) is not None: submodel_input_transform_options = { "InputPerturbation": input_transform_argparse( InputTransform, @@ -964,8 +823,9 @@ def _extract_construct_model_list_kwargs( submodel_input_transform_options = self.input_transform_options return ( - list(fidelity_features), + fidelity_features, task_feature, + search_space_digest.categorical_features, submodel_input_transform_classes, submodel_input_transform_options, ) diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 2223a1467f1..2d6f73f3915 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -31,7 +31,7 @@ from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel from botorch.models.model import Model, ModelList -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ChainedInputTransform from botorch.utils.datasets import SupervisedDataset @@ -108,11 +108,9 @@ def choose_model_class( "errors. Variances should all be specified, or none should be." ) - # Multi-task cases (when `task_features` specified). - if search_space_digest.task_features and all_inferred: - model_class = MultiTaskGP # Unknown observation noise. - elif search_space_digest.task_features: - model_class = FixedNoiseMultiTaskGP # Known observation noise. + # Multi-task case (when `task_features` is specified). + if search_space_digest.task_features: + model_class = MultiTaskGP # Single-task multi-fidelity cases. 
elif search_space_digest.fidelity_features and all_inferred: diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index 58b6d5db3ce..29a781cb517 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -117,12 +117,12 @@ def setUp(self) -> None: self.search_space_digest = SearchSpaceDigest( feature_names=self.feature_names, bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)], + fidelity_features=self.fidelity_features, target_values={2: 1.0}, ) self.surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.fidelity_features, search_space_digest=SearchSpaceDigest( feature_names=self.search_space_digest.feature_names[:1], bounds=self.search_space_digest.bounds, diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 25aa9d3b6be..ec8e49aee86 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -345,6 +345,7 @@ def test_fit(self, mock_fit: Mock) -> None: candidate_metadata=self.candidate_metadata, state_dict=None, refit=True, + additional_model_inputs={}, ) # ensure that error is raised when len(metric_names) != len(datasets) with self.assertRaisesRegex( diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 48fc788c3d6..7dffd1ba2f3 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -34,7 +34,7 @@ from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.model import Model, ModelList # noqa: F401 -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood from botorch.models.transforms.input import InputPerturbation, Normalize from botorch.models.transforms.outcome import Standardize @@ -198,7 +198,6 @@ def test_copy_options(self) -> None: ) # Change the lengthscales of one model and make sure the other isn't changed models[0].covar_module.base_kernel.lengthscale += 1 - self.assertTrue( torch.allclose( model1_old_lengtscale, @@ -364,15 +363,11 @@ def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None: search_space_digest=self.search_space_digest, ) mock_construct_inputs.assert_called_with( - training_data=self.training_data[0], some_option="some_value" - ) - - # botorch_model_class must be set to construct single model Surrogate - with self.assertRaisesRegex(ValueError, "botorch_model_class must be set"): - surrogate = Surrogate() - surrogate._construct_model( - dataset=self.training_data[0], - search_space_digest=self.search_space_digest, + training_data=self.training_data[0], + some_option="some_value", + fidelity_features=[], + task_feature=None, + categorical_features=[], ) def test_construct_custom_model(self) -> None: @@ -404,16 +399,16 @@ def test_construct_custom_model(self) -> None: metric_names=self.metric_names, search_space_digest=self.search_space_digest, ) - self.assertEqual(type(surrogate._model.likelihood), GaussianLikelihood) + model = not_none(surrogate._model) + self.assertEqual(type(model.likelihood), GaussianLikelihood) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `likelihood`. 
- surrogate._model.likelihood.noise_covar.raw_noise_constraint, - noise_constraint, + # Checking equality of __dict__'s since Interval does not define __eq__. + model.likelihood.noise_covar.raw_noise_constraint.__dict__, + noise_constraint.__dict__, ) self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) - self.assertEqual(type(surrogate._model.covar_module), RBFKernel) - # pyre-fixme[16]: Optional type has no attribute `covar_module`. - self.assertEqual(surrogate._model.covar_module.ard_num_dims, 1) + self.assertEqual(type(model.covar_module), RBFKernel) + self.assertEqual(model.covar_module.ard_num_dims, 1) @patch( f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.load_state_dict", @@ -498,7 +493,6 @@ def test_predict(self, mock_predict: Mock) -> None: surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.search_space_digest.fidelity_features, search_space_digest=self.search_space_digest, ) surrogate.predict(X=self.Xs[0]) @@ -510,7 +504,6 @@ def test_best_in_sample_point(self) -> None: surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.search_space_digest.fidelity_features, search_space_digest=self.search_space_digest, ) # `best_in_sample_point` requires `objective_weights` @@ -637,6 +630,7 @@ def test_update( candidate_metadata=None, refit=self.refit, state_dict={"key": torch.zeros(1)}, + additional_model_inputs=None, ) # Check that the training data is correctly passed through to the @@ -799,10 +793,10 @@ def setUp(self) -> None: self.outcomes = ["outcome_1", "outcome_2"] self.mll_class = ExactMarginalLogLikelihood self.dtype = torch.float + self.task_features = [0] self.search_space_digest = SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[0] + feature_names=[], bounds=[], task_features=self.task_features ) - self.task_features = [0] Xs1, Ys1, Yvars1, bounds, _, _, _ = get_torch_test_data( dtype=self.dtype, task_features=self.search_space_digest.task_features ) @@ -848,9 +842,9 @@ def setUp(self) -> None: search_space_digest=self.search_space_digest, ), } - self.botorch_model_class = FixedNoiseMultiTaskGP + self.botorch_model_class = MultiTaskGP for submodel_cls in self.botorch_submodel_class_per_outcome.values(): - self.assertEqual(submodel_cls, FixedNoiseMultiTaskGP) + self.assertEqual(submodel_cls, MultiTaskGP) self.Xs = Xs1 + Xs2 self.Ys = Ys1 + Ys2 self.Yvars = Yvars1 + Yvars2 @@ -878,7 +872,7 @@ def setUp(self) -> None: RANK: 1, } self.surrogate = Surrogate( - botorch_model_class=FixedNoiseMultiTaskGP, + botorch_model_class=MultiTaskGP, mll_class=self.mll_class, model_options=self.submodel_options_per_outcome, ) @@ -897,19 +891,21 @@ def test_init(self) -> None: self.surrogate.model @patch.object( - FixedNoiseMultiTaskGP, + MultiTaskGP, "construct_inputs", - wraps=FixedNoiseMultiTaskGP.construct_inputs, + wraps=MultiTaskGP.construct_inputs, ) def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock ) -> None: + self.surrogate.model_options.update({"output_tasks": [2]}) self.surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - output_tasks=[2], - search_space_digest=self.search_space_digest, - task_features=self.task_features, + search_space_digest=dataclasses.replace( + self.search_space_digest, + task_features=self.task_features, + ), ) # Should construct inputs for MTGP twice. 
self.assertEqual(len(mock_MTGP_construct_inputs.call_args_list), 2) @@ -919,6 +915,7 @@ def test_construct_per_outcome_options( # `call_args` is a tuple of (args, kwargs), and we check kwargs. mock_MTGP_construct_inputs.call_args_list[idx][1], { + "categorical_features": [], "fidelity_features": [], "task_feature": self.task_features[0], "training_data": SupervisedDataset( @@ -947,10 +944,10 @@ def test_construct_per_outcome_options_no_Yvar(self, _) -> None: # Test that splitting the training data works correctly when Yvar is None. surrogate.construct( datasets=self.supervised_training_data, - task_features=self.task_features, metric_names=self.outcomes, search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=self.task_features, bounds=self.bounds, ), ) @@ -960,9 +957,9 @@ def test_construct_per_outcome_options_no_Yvar(self, _) -> None: self.assertEqual(len(not_none(surrogate._training_data)), 2) @patch.object( - FixedNoiseMultiTaskGP, + MultiTaskGP, "construct_inputs", - wraps=FixedNoiseMultiTaskGP.construct_inputs, + wraps=MultiTaskGP.construct_inputs, ) def test_construct_per_outcome_error_raises( self, mock_MTGP_construct_inputs: Mock @@ -981,10 +978,10 @@ def test_construct_per_outcome_error_raises( surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - task_features=self.task_features, - fidelity_features=[1], search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=self.task_features, + fidelity_features=[1], bounds=self.bounds, ), ) @@ -995,9 +992,9 @@ def test_construct_per_outcome_error_raises( surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - task_features=[0, 1], search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=[0, 1], bounds=self.bounds, ), ) diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 3e1b40f52a3..fbfe39028b7 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -38,7 +38,7 @@ ) from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.model import ModelList -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, @@ -140,26 +140,17 @@ def test_choose_model_class_task_features(self) -> None: feature_names=[], bounds=[], task_features=[1, 2] ), ) - # With fidelity features and unknown variances, use SingleTaskMultiFidelityGP. - self.assertEqual( - MultiTaskGP, - choose_model_class( - datasets=self.supervised_datasets, - search_space_digest=SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[1] - ), - ), - ) - # With fidelity features and known variances, use FixedNoiseMultiFidelityGP. - self.assertEqual( - FixedNoiseMultiTaskGP, - choose_model_class( - datasets=self.fixed_noise_datasets, - search_space_digest=SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[1] + # With task features use MultiTaskGP. + for datasets in (self.supervised_datasets, self.fixed_noise_datasets): + self.assertEqual( + MultiTaskGP, + choose_model_class( + datasets=datasets, + search_space_digest=SearchSpaceDigest( + feature_names=[], bounds=[], task_features=[1] + ), ), - ), - ) + ) def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskyGP.
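
Usage sketch (editor's addendum, not part of the patch): the snippet below illustrates how the renamed `additional_model_inputs` argument reaches the BoTorch input constructor after this change. It is a minimal, hypothetical example assuming an Ax build that includes this diff and a compatible BoTorch version; the toy tensors, feature names, and the `output_tasks` value are made up for illustration.

# Sketch: extra model-input kwargs are now passed explicitly as
# `additional_model_inputs`, which updates `Surrogate.model_options` and is
# forwarded to `MultiTaskGP.construct_inputs` when the model is constructed.
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.models.multitask import MultiTaskGP
from botorch.utils.datasets import SupervisedDataset

# Hypothetical single-metric dataset; column 0 is the task feature.
X = torch.tensor([[0.0, 0.1], [0.0, 0.5], [1.0, 0.3], [1.0, 0.7]], dtype=torch.double)
Y = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.double)
dataset = SupervisedDataset(
    X=X, Y=Y, feature_names=["task", "x1"], outcome_names=["metric_1"]
)

surrogate = Surrogate(botorch_model_class=MultiTaskGP)
surrogate.fit(
    datasets=[dataset],
    metric_names=["metric_1"],
    search_space_digest=SearchSpaceDigest(
        feature_names=["task", "x1"],
        bounds=[(0.0, 1.0), (0.0, 1.0)],
        task_features=[0],
    ),
    # Previously smuggled through **kwargs; now an explicit dict that ends up
    # in `Surrogate.model_options` and in the model's input constructor.
    additional_model_inputs={"output_tasks": [0]},
)

Making this an explicit dict rather than free-floating **kwargs keeps the contract between `BoTorchModel.fit`, `Surrogate.fit`, and the input constructors visible at the call site, which is the intent stated in the summary above.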