diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index ca587c13cd1..5ad29d9fa4f 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -47,7 +47,7 @@ from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP from botorch.models.gp_regression import FixedNoiseGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.models.multitask import FixedNoiseMultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.utils.types import DEFAULT from gpytorch.kernels.matern_kernel import MaternKernel from gpytorch.kernels.scale_kernel import ScaleKernel @@ -475,7 +475,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: if use_saas else [ Surrogate( - botorch_model_class=FixedNoiseMultiTaskGP, + botorch_model_class=MultiTaskGP, mll_class=ExactMarginalLogLikelihood, covar_module_class=ScaleMaternKernel, covar_module_options={ @@ -514,9 +514,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: for i in range(len(models)): self.assertIsInstance( models[i], - SaasFullyBayesianMultiTaskGP - if use_saas - else FixedNoiseMultiTaskGP, + SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP, ) if use_saas is False: self.assertIsInstance(models[i].covar_module, ScaleKernel) diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 926c5c9ad9d..ee80a474000 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -24,7 +24,6 @@ from ax.models.torch.botorch_modular.utils import ( choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, - convert_to_block_design, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -33,7 +32,6 @@ from ax.utils.common.docutils import copy_doc from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.models import ModelList from botorch.models.deterministic import FixedSingleSampleModel from botorch.models.model import Model from botorch.models.transforms.input import InputTransform @@ -247,7 +245,7 @@ def fit( # state dict by surrogate label state_dicts: Optional[Mapping[str, Dict[str, Tensor]]] = None, refit: bool = True, - **kwargs: Any, + **additional_model_inputs: Any, ) -> None: """Fit model to m outcomes. @@ -264,6 +262,8 @@ def fit( surrogate_specs. If using a single, pre-instantiated model use `Keys.ONLY_SURROGATE. refit: Whether to re-optimize model parameters. + additional_model_inputs: Additional kwargs to pass to the + model input constructor in ``Surrogate.fit``. """ if len(datasets) != len(metric_names): @@ -288,7 +288,7 @@ def fit( if state_dicts else None, refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) return @@ -340,20 +340,6 @@ def fit( datasets_by_metric_name[metric_name] for metric_name in subset_metric_names ] - if ( - len(subset_datasets) > 1 - # if Surrogate's model is none a ModelList will be autoset - and surrogate._model is not None - and not isinstance(surrogate.model, ModelList) - ): - # Note: If the datasets do not confirm to a block design then this - # will filter the data and drop observations to make sure that it does. - # This can happen e.g. if only some metrics are observed at some points - subset_datasets, metric_names = convert_to_block_design( - datasets=subset_datasets, - metric_names=metric_names, - force=True, - ) surrogate.fit( datasets=subset_datasets, @@ -362,7 +348,7 @@ def fit( candidate_metadata=candidate_metadata, state_dict=(state_dicts or {}).get(label), refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) @copy_doc(TorchModel.update) @@ -372,7 +358,7 @@ def update( metric_names: List[str], search_space_digest: SearchSpaceDigest, candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, - **kwargs: Any, + **additional_model_inputs: Any, ) -> None: if len(self.surrogates) == 0: raise UnsupportedError("Cannot update model that has not been fitted.") @@ -417,7 +403,7 @@ def update( candidate_metadata=candidate_metadata, state_dict=state_dict, refit=self.refit_on_update, - **kwargs, + additional_model_inputs=additional_model_inputs, ) @single_surrogate_only @@ -536,7 +522,7 @@ def cross_validate( metric_names: List[str], X_test: Tensor, search_space_digest: SearchSpaceDigest, - **kwargs: Any, + **additional_model_inputs: Any, ) -> Tuple[Tensor, Tensor]: # Will fail if metric_names exist across multiple models surrogate_labels = ( @@ -589,7 +575,7 @@ def cross_validate( search_space_digest=search_space_digest, state_dicts=state_dicts, refit=self.refit_on_cv, - **kwargs, + **additional_model_inputs, ) X_test_prediction = self.predict_from_surrogate( surrogate_label=surrogate_label, X=X_test diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index e1ad27c322f..b7122373dbb 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -6,16 +6,14 @@ from __future__ import annotations -import dataclasses - import inspect import warnings from copy import deepcopy from logging import Logger -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch -from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest +from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError from ax.models.model_utils import best_in_sample_point @@ -28,7 +26,6 @@ from ax.models.torch.botorch_modular.input_constructors.outcome_transform import ( outcome_transform_argparse, ) - from ax.models.torch.botorch_modular.utils import ( choose_model_class, convert_to_block_design, @@ -55,7 +52,6 @@ InputTransform, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform - from botorch.utils.datasets import RankingDataset, SupervisedDataset from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood @@ -126,20 +122,22 @@ class string names and the values are dictionaries of input transform Set to false to fit individual models to each metric in a loop. """ + # These attributes are instantiated in __init__. botorch_model_class: Optional[Type[Model]] model_options: Dict[str, Any] mll_class: Type[MarginalLogLikelihood] mll_options: Dict[str, Any] - outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None - outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None - input_transform_classes: Optional[List[Type[InputTransform]]] = None - input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None - covar_module_class: Optional[Type[Kernel]] = None + outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] + outcome_transform_options: Dict[str, Dict[str, Any]] + input_transform_classes: Optional[List[Type[InputTransform]]] + input_transform_options: Dict[str, Dict[str, Any]] + covar_module_class: Optional[Type[Kernel]] covar_module_options: Dict[str, Any] - likelihood_class: Optional[Type[Likelihood]] = None + likelihood_class: Optional[Type[Likelihood]] likelihood_options: Dict[str, Any] - allow_batched_models: bool = True + allow_batched_models: bool + # These are later updated during model fitting. _training_data: Optional[List[SupervisedDataset]] = None _outcomes: Optional[List[str]] = None _model: Optional[Model] = None @@ -250,7 +248,6 @@ def construct( datasets: List[SupervisedDataset], metric_names: List[str], search_space_digest: SearchSpaceDigest, - **kwargs: Any, ) -> None: """Constructs the underlying BoTorch ``Model`` using the training data. @@ -261,11 +258,7 @@ def construct( corresponding to the i-th dataset. search_space_digest: Information about the search space used for inferring suitable botorch model class. - **kwargs: Optional keyword arguments, expects any of: - - "fidelity_features": Indices of columns in X that represent - fidelity. """ - if self._constructed_manually: logger.warning("Reconstructing a manually constructed `Model`.") @@ -277,41 +270,40 @@ def construct( search_space_digest=not_none(search_space_digest), ) - if use_model_list( + should_use_model_list = use_model_list( datasets=datasets, botorch_model_class=botorch_model_class, allow_batched_models=self.allow_batched_models, - ): - self._construct_model_list( + ) + + if not should_use_model_list and len(datasets) > 1: + datasets, metric_names = convert_to_block_design( datasets=datasets, metric_names=metric_names, - search_space_digest=search_space_digest, - **kwargs, + force=True, ) - else: - if self.botorch_model_class is None: - self.botorch_model_class = botorch_model_class - - if len(datasets) > 1: - datasets, metric_names = convert_to_block_design( - datasets=datasets, - metric_names=metric_names, - force=True, - ) - kwargs["metric_names"] = metric_names + self._training_data = datasets - self._construct_model( - dataset=datasets[0], + models = [] + for dataset in datasets: + model = self._construct_model( + dataset=dataset, search_space_digest=search_space_digest, - **kwargs, + botorch_model_class=botorch_model_class, ) + models.append(model) + + if should_use_model_list: + self._model = ModelListGP(*models) + else: + self._model = models[0] def _construct_model( self, dataset: SupervisedDataset, search_space_digest: SearchSpaceDigest, - **kwargs: Any, - ) -> None: + botorch_model_class: Type[Model], + ) -> Model: """Constructs the underlying BoTorch ``Model`` using the training data. Args: @@ -320,26 +312,30 @@ def _construct_model( multi-output case, where training data is formatted with just one X and concatenated Ys). search_space_digest: Search space digest used to set up model arguments. - **kwargs: Optional keyword arguments, expects any of: - - "fidelity_features": Indices of columns in X that represent - fidelity. + botorch_model_class: ``Model`` class to be used as the underlying + BoTorch model. """ - if self.botorch_model_class is None: - raise ValueError( - "botorch_model_class must be set to construct single model Surrogate." - ) - botorch_model_class = self.botorch_model_class + ( + fidelity_features, + task_feature, + categorical_features, + input_transform_classes, + input_transform_options, + ) = self._extract_construct_model_list_kwargs( + search_space_digest=search_space_digest, + ) - input_constructor_kwargs = {**self.model_options, **(kwargs or {})} + input_constructor_kwargs = { + **self.model_options, + "fidelity_features": fidelity_features, + "task_feature": task_feature, + "categorical_features": categorical_features, + } botorch_model_class_args = inspect.getfullargspec(botorch_model_class).args # Temporary workaround to allow models to consume data from - # `FixedNoiseDataset`s even if they don't accept variance observations - if ( - "train_Yvar" not in botorch_model_class_args - and dataset.Yvar is not None - and not isinstance(dataset, RankingDataset) - ): + # `FixedNoiseDataset`s even if they don't accept variance observations. + if "train_Yvar" not in botorch_model_class_args and dataset.Yvar is not None: warnings.warn( f"Provided model class {botorch_model_class} does not accept " "`train_Yvar` argument, but received dataset with `Yvar`. Ignoring " @@ -354,157 +350,50 @@ def _construct_model( outcome_names=dataset.outcome_names, ) - self._training_data = [dataset] - formatted_model_inputs = botorch_model_class.construct_inputs( training_data=dataset, **input_constructor_kwargs ) self._set_formatted_inputs( formatted_model_inputs=formatted_model_inputs, inputs=[ - [ + ( "covar_module", self.covar_module_class, self.covar_module_options, - None, - ], - ["likelihood", self.likelihood_class, self.likelihood_options, None], - [ + ), + ("likelihood", self.likelihood_class, self.likelihood_options), + ( "outcome_transform", self.outcome_transform_classes, self.outcome_transform_options, - None, - ], - [ + ), + ( "input_transform", - self.input_transform_classes, - self.input_transform_options, - None, - ], + input_transform_classes, + deepcopy(input_transform_options), + ), ], dataset=dataset, search_space_digest=search_space_digest, botorch_model_class_args=botorch_model_class_args, ) # pyre-ignore [45] - self._model = botorch_model_class(**formatted_model_inputs) - - def _construct_model_list( - self, - datasets: List[SupervisedDataset], - metric_names: Iterable[str], - search_space_digest: SearchSpaceDigest, - **kwargs: Any, - ) -> None: - """Constructs the underlying BoTorch ``Model`` using the training data. - - Args: - datasets: List of ``SupervisedDataset`` for the submodels of - ``ModelListGP``. Each training data is for one outcome, and the order - of outcomes should match the order of metrics in ``metric_names`` - argument. - metric_names: Names of metrics, in the same order as datasets (so if - datasets is ``[ds_A, ds_B]``, the metrics are ``["A" and "B"]``). - These are used to match training data with correct submodels of - ``ModelListGP``. - search_space_digest: SearchSpaceDigest must be provided if no - botorch_submodel_class is provided so the appropriate botorch model - class can be automatically selected. - - **kwargs: Keyword arguments, accepts: - - ``fidelity_features``: Indices of columns in X that represent - fidelity features. - - ``task_features``: Indices of columns in X that represent tasks. - - ``categorical_features``: Indices of columns in X that represent - categorical features. - - ``robust_digest``: An optional `RobustSearchSpaceDigest` that carries - additional attributes if using a `RobustSearchSpace`. - """ - - self._training_data = datasets - - ( - fidelity_features, - task_feature, - submodel_input_transform_classes, - submodel_input_transform_options, - ) = self._extract_construct_model_list_kwargs( - fidelity_features=kwargs.pop(Keys.FIDELITY_FEATURES, []), - task_features=kwargs.pop(Keys.TASK_FEATURES, []), - robust_digest=kwargs.pop("robust_digest", None), - ) - - input_constructor_kwargs = {**self.model_options, **(kwargs or {})} - - submodels = [] - for m, dataset in zip(metric_names, datasets): - model_cls = self.botorch_model_class or choose_model_class( - datasets=[dataset], search_space_digest=not_none(search_space_digest) - ) - - if self._outcomes is not None and m not in self._outcomes: - logger.warning(f"Metric {m} not in training data.") - continue - formatted_model_inputs = model_cls.construct_inputs( - training_data=dataset, - fidelity_features=fidelity_features, - task_feature=task_feature, - **input_constructor_kwargs, - ) - # Add input / outcome transforms. - # TODO: The use of `inspect` here is not ideal. We should find a better - # way to filter the arguments. See the comment in `Surrogate.construct` - # regarding potential use of a `ModelFactory` in the future. - model_cls_args = inspect.getfullargspec(model_cls).args - self._set_formatted_inputs( - formatted_model_inputs=formatted_model_inputs, - inputs=[ - [ - "covar_module", - self.covar_module_class, - deepcopy(self.covar_module_options), - None, - ], - [ - "likelihood", - self.likelihood_class, - deepcopy(self.likelihood_options), - None, - ], - [ - "outcome_transform", - self.outcome_transform_classes, - deepcopy(self.outcome_transform_options), - None, - ], - [ - "input_transform", - submodel_input_transform_classes, - deepcopy(submodel_input_transform_options), - None, - ], - ], - dataset=dataset, - search_space_digest=search_space_digest, - botorch_model_class_args=model_cls_args, - ) - # pyre-ignore[45]: Py raises informative error if model is abstract. - submodels.append(model_cls(**formatted_model_inputs)) - - self._model = ModelListGP(*submodels) + return botorch_model_class(**formatted_model_inputs) def _set_formatted_inputs( self, formatted_model_inputs: Dict[str, Any], - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - inputs: List[List[Any]], + # pyre-ignore [2] The proper hint for the second arg is Union[None, + # Type[Kernel], Type[Likelihood], List[Type[OutcomeTransform]], + # List[Type[InputTransform]]]. Keeping it as Any saves us from a + # bunch of checked_cast calls within the for loop. + inputs: List[Tuple[str, Any, Dict[str, Any]]], dataset: SupervisedDataset, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - botorch_model_class_args: Any, + botorch_model_class_args: List[str], search_space_digest: SearchSpaceDigest, ) -> None: - for input_name, input_class, input_options, input_object in inputs: - if input_class is None and input_object is None: + for input_name, input_class, input_options in inputs: + if input_class is None: continue if input_name not in botorch_model_class_args: # TODO: We currently only pass in `covar_module` and `likelihood` @@ -515,68 +404,38 @@ def _set_formatted_inputs( f"The BoTorch model class {self.botorch_model_class} does not " f"support the input {input_name}." ) - if input_class is not None and input_object is not None: - raise RuntimeError(f"Got both a class and an object for {input_name}.") - if input_class is not None: - input_options = input_options or {} - - if input_name == "covar_module": - covar_module_with_defaults = covar_module_argparse( - input_class, - dataset=dataset, - botorch_model_class=self.botorch_model_class, - **input_options, - ) - - formatted_model_inputs[input_name] = input_class( - **covar_module_with_defaults - ) - - elif input_name == "input_transform": - - formatted_model_inputs[ - input_name - ] = self._make_botorch_input_transform( - input_classes=input_class, - input_options=input_options, - dataset=dataset, - search_space_digest=search_space_digest, - ) - - elif input_name == "outcome_transform": - - formatted_model_inputs[ - input_name - ] = self._make_botorch_outcome_transform( - input_classes=input_class, - input_options=input_options, - dataset=dataset, - ) - else: - formatted_model_inputs[input_name] = input_class(**input_options) - - else: - formatted_model_inputs[input_name] = input_object + input_options = deepcopy(input_options) or {} + + if input_name == "covar_module": + covar_module_with_defaults = covar_module_argparse( + input_class, + dataset=dataset, + botorch_model_class=self.botorch_model_class, + **input_options, + ) - # Construct input perturbation if doing robust optimization. - robust_digest = search_space_digest.robust_digest - if robust_digest is not None: + formatted_model_inputs[input_name] = input_class( + **covar_module_with_defaults + ) - perturbation = self._make_botorch_input_transform( - input_classes=[InputPerturbation], - dataset=dataset, - search_space_digest=search_space_digest, - input_options={}, - ) + elif input_name == "input_transform": + formatted_model_inputs[input_name] = self._make_botorch_input_transform( + input_classes=input_class, + input_options=input_options, + dataset=dataset, + search_space_digest=search_space_digest, + ) - if formatted_model_inputs.get("input_transform") is not None: - # TODO: Support mixing with user supplied transforms. - raise NotImplementedError( - "User supplied input transforms are not supported " - "in robust optimization." + elif input_name == "outcome_transform": + formatted_model_inputs[ + input_name + ] = self._make_botorch_outcome_transform( + input_classes=input_class, + input_options=input_options, + dataset=dataset, ) else: - formatted_model_inputs["input_transform"] = perturbation + formatted_model_inputs[input_name] = input_class(**input_options) def _make_botorch_input_transform( self, @@ -675,7 +534,7 @@ def fit( candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, state_dict: Optional[Dict[str, Tensor]] = None, refit: bool = True, - **kwargs: Any, + additional_model_inputs: Optional[Dict[str, Any]] = None, ) -> None: """Fits the underlying BoTorch ``Model`` to ``m`` outcomes. @@ -706,7 +565,10 @@ def fit( the order corresponding to the Xs. state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters. + additional_model_inputs: Additional kwargs to pass to the + model input constructor. """ + self.model_options.update(additional_model_inputs or {}) if self._constructed_manually: logger.debug( "For manually constructed surrogates (via `Surrogate.from_botorch`), " @@ -714,13 +576,10 @@ def fit( "its parameters if `refit=True`." ) else: - _kwargs = dataclasses.asdict(search_space_digest) - _kwargs.update(kwargs) self.construct( datasets=datasets, metric_names=metric_names, search_space_digest=search_space_digest, - **_kwargs, ) self._outcomes = metric_names @@ -827,7 +686,7 @@ def update( candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, state_dict: Optional[Dict[str, Tensor]] = None, refit: bool = True, - **kwargs: Any, + additional_model_inputs: Optional[Dict[str, Any]] = None, ) -> None: """Updates the surrogate model with new data. In the base ``Surrogate``, just calls ``fit`` after checking that this surrogate was not created @@ -848,6 +707,8 @@ def update( state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters or just set the training data used for interence to new training data. + additional_model_inputs: Additional kwargs to pass to the + model input constructor. """ # NOTE: In the future, could have `incremental` kwarg, in which case # `training_data` could contain just the new data. @@ -865,7 +726,7 @@ def update( candidate_metadata=candidate_metadata, state_dict=state_dict, refit=refit, - **kwargs, + additional_model_inputs=additional_model_inputs, ) def pareto_frontier(self) -> Tuple[Tensor, Tensor]: @@ -910,17 +771,16 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]: } def _extract_construct_model_list_kwargs( - self, - fidelity_features: Sequence[int], - task_features: Sequence[int], - robust_digest: Optional[RobustSearchSpaceDigest] = None, + self, search_space_digest: SearchSpaceDigest ) -> Tuple[ List[int], Optional[int], + List[int], Optional[List[Type[InputTransform]]], - Optional[Dict[str, Dict["str", Any]]], + Dict[str, Dict[str, Any]], ]: - + fidelity_features = search_space_digest.fidelity_features + task_features = search_space_digest.task_features if len(fidelity_features) > 0 and len(task_features) > 0: raise NotImplementedError( "Multi-Fidelity GP models with task_features are " @@ -938,8 +798,7 @@ def _extract_construct_model_list_kwargs( # Construct input perturbation if doing robust optimization. # NOTE: Doing this here rather than in `_set_formatted_inputs` to make sure # we use the same perturbations for each sub-model. - if robust_digest is not None: - + if (robust_digest := search_space_digest.robust_digest) is not None: submodel_input_transform_options = { "InputPerturbation": input_transform_argparse( InputTransform, @@ -964,8 +823,9 @@ def _extract_construct_model_list_kwargs( submodel_input_transform_options = self.input_transform_options return ( - list(fidelity_features), + fidelity_features, task_feature, + search_space_digest.categorical_features, submodel_input_transform_classes, submodel_input_transform_options, ) diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 2223a1467f1..2d6f73f3915 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -31,7 +31,7 @@ from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel from botorch.models.model import Model, ModelList -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ChainedInputTransform from botorch.utils.datasets import SupervisedDataset @@ -108,11 +108,9 @@ def choose_model_class( "errors. Variances should all be specified, or none should be." ) - # Multi-task cases (when `task_features` specified). - if search_space_digest.task_features and all_inferred: - model_class = MultiTaskGP # Unknown observation noise. - elif search_space_digest.task_features: - model_class = FixedNoiseMultiTaskGP # Known observation noise. + # Multi-task case (when `task_features` is specified). + if search_space_digest.task_features: + model_class = MultiTaskGP # Single-task multi-fidelity cases. elif search_space_digest.fidelity_features and all_inferred: diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index 58b6d5db3ce..29a781cb517 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -117,12 +117,12 @@ def setUp(self) -> None: self.search_space_digest = SearchSpaceDigest( feature_names=self.feature_names, bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)], + fidelity_features=self.fidelity_features, target_values={2: 1.0}, ) self.surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.fidelity_features, search_space_digest=SearchSpaceDigest( feature_names=self.search_space_digest.feature_names[:1], bounds=self.search_space_digest.bounds, diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 25aa9d3b6be..ec8e49aee86 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -345,6 +345,7 @@ def test_fit(self, mock_fit: Mock) -> None: candidate_metadata=self.candidate_metadata, state_dict=None, refit=True, + additional_model_inputs={}, ) # ensure that error is raised when len(metric_names) != len(datasets) with self.assertRaisesRegex( diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 48fc788c3d6..7dffd1ba2f3 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -34,7 +34,7 @@ from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.model import Model, ModelList # noqa: F401 -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood from botorch.models.transforms.input import InputPerturbation, Normalize from botorch.models.transforms.outcome import Standardize @@ -198,7 +198,6 @@ def test_copy_options(self) -> None: ) # Change the lengthscales of one model and make sure the other isn't changed models[0].covar_module.base_kernel.lengthscale += 1 - self.assertTrue( torch.allclose( model1_old_lengtscale, @@ -364,15 +363,11 @@ def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None: search_space_digest=self.search_space_digest, ) mock_construct_inputs.assert_called_with( - training_data=self.training_data[0], some_option="some_value" - ) - - # botorch_model_class must be set to construct single model Surrogate - with self.assertRaisesRegex(ValueError, "botorch_model_class must be set"): - surrogate = Surrogate() - surrogate._construct_model( - dataset=self.training_data[0], - search_space_digest=self.search_space_digest, + training_data=self.training_data[0], + some_option="some_value", + fidelity_features=[], + task_feature=None, + categorical_features=[], ) def test_construct_custom_model(self) -> None: @@ -404,16 +399,16 @@ def test_construct_custom_model(self) -> None: metric_names=self.metric_names, search_space_digest=self.search_space_digest, ) - self.assertEqual(type(surrogate._model.likelihood), GaussianLikelihood) + model = not_none(surrogate._model) + self.assertEqual(type(model.likelihood), GaussianLikelihood) self.assertEqual( - # pyre-fixme[16]: Optional type has no attribute `likelihood`. - surrogate._model.likelihood.noise_covar.raw_noise_constraint, - noise_constraint, + # Checking equality of __dict__'s since Interval does not define __eq__. + model.likelihood.noise_covar.raw_noise_constraint.__dict__, + noise_constraint.__dict__, ) self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) - self.assertEqual(type(surrogate._model.covar_module), RBFKernel) - # pyre-fixme[16]: Optional type has no attribute `covar_module`. - self.assertEqual(surrogate._model.covar_module.ard_num_dims, 1) + self.assertEqual(type(model.covar_module), RBFKernel) + self.assertEqual(model.covar_module.ard_num_dims, 1) @patch( f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.load_state_dict", @@ -498,7 +493,6 @@ def test_predict(self, mock_predict: Mock) -> None: surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.search_space_digest.fidelity_features, search_space_digest=self.search_space_digest, ) surrogate.predict(X=self.Xs[0]) @@ -510,7 +504,6 @@ def test_best_in_sample_point(self) -> None: surrogate.construct( datasets=self.training_data, metric_names=self.metric_names, - fidelity_features=self.search_space_digest.fidelity_features, search_space_digest=self.search_space_digest, ) # `best_in_sample_point` requires `objective_weights` @@ -637,6 +630,7 @@ def test_update( candidate_metadata=None, refit=self.refit, state_dict={"key": torch.zeros(1)}, + additional_model_inputs=None, ) # Check that the training data is correctly passed through to the @@ -799,10 +793,10 @@ def setUp(self) -> None: self.outcomes = ["outcome_1", "outcome_2"] self.mll_class = ExactMarginalLogLikelihood self.dtype = torch.float + self.task_features = [0] self.search_space_digest = SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[0] + feature_names=[], bounds=[], task_features=self.task_features ) - self.task_features = [0] Xs1, Ys1, Yvars1, bounds, _, _, _ = get_torch_test_data( dtype=self.dtype, task_features=self.search_space_digest.task_features ) @@ -848,9 +842,9 @@ def setUp(self) -> None: search_space_digest=self.search_space_digest, ), } - self.botorch_model_class = FixedNoiseMultiTaskGP + self.botorch_model_class = MultiTaskGP for submodel_cls in self.botorch_submodel_class_per_outcome.values(): - self.assertEqual(submodel_cls, FixedNoiseMultiTaskGP) + self.assertEqual(submodel_cls, MultiTaskGP) self.Xs = Xs1 + Xs2 self.Ys = Ys1 + Ys2 self.Yvars = Yvars1 + Yvars2 @@ -878,7 +872,7 @@ def setUp(self) -> None: RANK: 1, } self.surrogate = Surrogate( - botorch_model_class=FixedNoiseMultiTaskGP, + botorch_model_class=MultiTaskGP, mll_class=self.mll_class, model_options=self.submodel_options_per_outcome, ) @@ -897,19 +891,21 @@ def test_init(self) -> None: self.surrogate.model @patch.object( - FixedNoiseMultiTaskGP, + MultiTaskGP, "construct_inputs", - wraps=FixedNoiseMultiTaskGP.construct_inputs, + wraps=MultiTaskGP.construct_inputs, ) def test_construct_per_outcome_options( self, mock_MTGP_construct_inputs: Mock ) -> None: + self.surrogate.model_options.update({"output_tasks": [2]}) self.surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - output_tasks=[2], - search_space_digest=self.search_space_digest, - task_features=self.task_features, + search_space_digest=dataclasses.replace( + self.search_space_digest, + task_features=self.task_features, + ), ) # Should construct inputs for MTGP twice. self.assertEqual(len(mock_MTGP_construct_inputs.call_args_list), 2) @@ -919,6 +915,7 @@ def test_construct_per_outcome_options( # `call_args` is a tuple of (args, kwargs), and we check kwargs. mock_MTGP_construct_inputs.call_args_list[idx][1], { + "categorical_features": [], "fidelity_features": [], "task_feature": self.task_features[0], "training_data": SupervisedDataset( @@ -947,10 +944,10 @@ def test_construct_per_outcome_options_no_Yvar(self, _) -> None: # Test that splitting the training data works correctly when Yvar is None. surrogate.construct( datasets=self.supervised_training_data, - task_features=self.task_features, metric_names=self.outcomes, search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=self.task_features, bounds=self.bounds, ), ) @@ -960,9 +957,9 @@ def test_construct_per_outcome_options_no_Yvar(self, _) -> None: self.assertEqual(len(not_none(surrogate._training_data)), 2) @patch.object( - FixedNoiseMultiTaskGP, + MultiTaskGP, "construct_inputs", - wraps=FixedNoiseMultiTaskGP.construct_inputs, + wraps=MultiTaskGP.construct_inputs, ) def test_construct_per_outcome_error_raises( self, mock_MTGP_construct_inputs: Mock @@ -981,10 +978,10 @@ def test_construct_per_outcome_error_raises( surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - task_features=self.task_features, - fidelity_features=[1], search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=self.task_features, + fidelity_features=[1], bounds=self.bounds, ), ) @@ -995,9 +992,9 @@ def test_construct_per_outcome_error_raises( surrogate.construct( datasets=self.fixed_noise_training_data, metric_names=self.outcomes, - task_features=[0, 1], search_space_digest=SearchSpaceDigest( feature_names=self.feature_names, + task_features=[0, 1], bounds=self.bounds, ), ) diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index 3e1b40f52a3..fbfe39028b7 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -38,7 +38,7 @@ ) from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.model import ModelList -from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import ( ChainedInputTransform, InputPerturbation, @@ -140,26 +140,17 @@ def test_choose_model_class_task_features(self) -> None: feature_names=[], bounds=[], task_features=[1, 2] ), ) - # With fidelity features and unknown variances, use SingleTaskMultiFidelityGP. - self.assertEqual( - MultiTaskGP, - choose_model_class( - datasets=self.supervised_datasets, - search_space_digest=SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[1] - ), - ), - ) - # With fidelity features and known variances, use FixedNoiseMultiFidelityGP. - self.assertEqual( - FixedNoiseMultiTaskGP, - choose_model_class( - datasets=self.fixed_noise_datasets, - search_space_digest=SearchSpaceDigest( - feature_names=[], bounds=[], task_features=[1] + # With task features use MultiTaskGP. + for datasets in (self.supervised_datasets, self.fixed_noise_datasets): + self.assertEqual( + MultiTaskGP, + choose_model_class( + datasets=datasets, + search_space_digest=SearchSpaceDigest( + feature_names=[], bounds=[], task_features=[1] + ), ), - ), - ) + ) def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskyGP.