diff --git a/ax/models/tests/test_botorch_defaults.py b/ax/models/tests/test_botorch_defaults.py
index ac69fc7e8e3..bb9e39de380 100644
--- a/ax/models/tests/test_botorch_defaults.py
+++ b/ax/models/tests/test_botorch_defaults.py
@@ -32,6 +32,7 @@
 )
 from botorch.acquisition.objective import ConstrainedMCObjective
 from botorch.acquisition.penalized import L1PenaltyObjective, PenalizedMCObjective
+from botorch.exceptions.errors import UnsupportedError
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
 from botorch.models.multitask import MultiTaskGP
@@ -215,6 +216,15 @@ def test_get_model(self) -> None:
         self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
         self.assertEqual(covar_module, model.covar_module)
 
+        # test input warping dimension checks.
+        with self.assertRaisesRegex(UnsupportedError, "batched multi output models"):
+            _get_model(
+                X=torch.ones(4, 3, 2),
+                Y=torch.ones(4, 3, 2),
+                Yvar=torch.zeros(4, 3, 2),
+                use_input_warping=True,
+            )
+
     @mock.patch("ax.models.torch.botorch_defaults._get_model", wraps=_get_model)
     @fast_botorch_optimize
     # pyre-fixme[3]: Return type must be annotated.
@@ -468,7 +478,7 @@ def test_get_customized_covar_module(self) -> None:
         covar_module = _get_customized_covar_module(
             covar_module_prior_dict={},
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            aug_batch_shape=batch_shape,
             task_feature=None,
         )
         self.assertIsInstance(covar_module, Module)
@@ -495,7 +505,7 @@
                 "outputscale_prior": GammaPrior(2.0, 12.0),
             },
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            aug_batch_shape=batch_shape,
             task_feature=3,
         )
         self.assertIsInstance(covar_module, Module)
diff --git a/ax/models/torch/botorch_defaults.py b/ax/models/torch/botorch_defaults.py
index 97faa4adb68..e12dfd200d6 100644
--- a/ax/models/torch/botorch_defaults.py
+++ b/ax/models/torch/botorch_defaults.py
@@ -711,7 +711,6 @@ def _get_model(
     is_nan = torch.isnan(Yvar)
     any_nan_Yvar = torch.any(is_nan)
     all_nan_Yvar = torch.all(is_nan)
-    batch_shape = _get_batch_shape(X, Y)
     if any_nan_Yvar and not all_nan_Yvar:
         if task_feature:
             # TODO (jej): Replace with inferred noise before making perf judgements.
@@ -722,10 +721,14 @@ def _get_model(
                 "errors. Variances should all be specified, or none should be."
             )
     if use_input_warping:
+        if Y.shape[-1] > 1 and X.ndim > 2:
+            raise UnsupportedError(
+                "Input warping is not supported for batched multi output models."
+            )
         warp_tf = get_warping_transform(
             d=X.shape[-1],
             task_feature=task_feature,
-            batch_shape=batch_shape,
+            batch_shape=X.shape[:-2],
         )
     else:
         warp_tf = None
@@ -741,7 +744,7 @@ def _get_model(
         covar_module = _get_customized_covar_module(
             covar_module_prior_dict=covar_module_prior_dict,
             ard_num_dims=X.shape[-1],
-            batch_shape=batch_shape,
+            aug_batch_shape=_get_aug_batch_shape(X, Y),
             task_feature=task_feature,
         )
 
@@ -804,7 +807,7 @@
 def _get_customized_covar_module(
     covar_module_prior_dict: Dict[str, Prior],
     ard_num_dims: int,
-    batch_shape: torch.Size,
+    aug_batch_shape: torch.Size,
     task_feature: Optional[int] = None,
 ) -> Kernel:
     """Construct a GP kernel based on customized prior dict.
@@ -812,9 +815,10 @@ def _get_customized_covar_module(
     Args:
         covar_module_prior_dict: Dict. The keys are the names of the prior and
            values are the priors. e.g. {"lengthscale_prior": GammaPrior(3.0, 6.0)}.
-        ard_num_dims: The dimension of the input, including task features
-        batch_shape: The batch_shape of the model
-        task_feature: The index of the task feature
+        ard_num_dims: The dimension of the inputs, including task features.
+        aug_batch_shape: The output dimension augmented batch shape of the model
+            (different from the batch shape for batched multi-output models).
+        task_feature: The index of the task feature.
     """
     # TODO: add more checks of covar_module_prior_dict
     if task_feature is not None:
@@ -823,19 +827,19 @@
         MaternKernel(
             nu=2.5,
             ard_num_dims=ard_num_dims,
-            batch_shape=batch_shape,
+            batch_shape=aug_batch_shape,
             lengthscale_prior=covar_module_prior_dict.get(
                 "lengthscale_prior", GammaPrior(3.0, 6.0)
             ),
         ),
-        batch_shape=batch_shape,
+        batch_shape=aug_batch_shape,
         outputscale_prior=covar_module_prior_dict.get(
             "outputscale_prior", GammaPrior(2.0, 0.15)
         ),
     )
 
 
-def _get_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
+def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
     """Obtain the output-augmented batch shape of GP model.
 
     Args: