Skip to content

Commit

Permalink
Fix usage of batch shape for warp transform (#1994)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1994

Follows up on pytorch/botorch#2109 to fix the potential for erroneous usage.

Reviewed By: Balandat

Differential Revision: D51369374

fbshipit-source-id: e547f773d1fd6a8d22a986ba404afe8fc5e91fb9
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 16, 2023
1 parent 1fe5d57 commit 481d749
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
14 changes: 12 additions & 2 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from botorch.acquisition.objective import ConstrainedMCObjective
from botorch.acquisition.penalized import L1PenaltyObjective, PenalizedMCObjective
from botorch.exceptions.errors import UnsupportedError
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.multitask import MultiTaskGP
Expand Down Expand Up @@ -215,6 +216,15 @@ def test_get_model(self) -> None:
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
self.assertEqual(covar_module, model.covar_module)

# test input warping dimension checks.
with self.assertRaisesRegex(UnsupportedError, "batched multi output models"):
_get_model(
X=torch.ones(4, 3, 2),
Y=torch.ones(4, 3, 2),
Yvar=torch.zeros(4, 3, 2),
use_input_warping=True,
)

@mock.patch("ax.models.torch.botorch_defaults._get_model", wraps=_get_model)
@fast_botorch_optimize
# pyre-fixme[3]: Return type must be annotated.
Expand Down Expand Up @@ -468,7 +478,7 @@ def test_get_customized_covar_module(self) -> None:
covar_module = _get_customized_covar_module(
covar_module_prior_dict={},
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
aug_batch_shape=batch_shape,
task_feature=None,
)
self.assertIsInstance(covar_module, Module)
Expand All @@ -495,7 +505,7 @@ def test_get_customized_covar_module(self) -> None:
"outputscale_prior": GammaPrior(2.0, 12.0),
},
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
aug_batch_shape=batch_shape,
task_feature=3,
)
self.assertIsInstance(covar_module, Module)
Expand Down
24 changes: 14 additions & 10 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,6 @@ def _get_model(
is_nan = torch.isnan(Yvar)
any_nan_Yvar = torch.any(is_nan)
all_nan_Yvar = torch.all(is_nan)
batch_shape = _get_batch_shape(X, Y)
if any_nan_Yvar and not all_nan_Yvar:
if task_feature:
# TODO (jej): Replace with inferred noise before making perf judgements.
Expand All @@ -722,10 +721,14 @@ def _get_model(
"errors. Variances should all be specified, or none should be."
)
if use_input_warping:
if Y.shape[-1] > 1 and X.ndim > 2:
raise UnsupportedError(
"Input warping is not supported for batched multi output models."
)
warp_tf = get_warping_transform(
d=X.shape[-1],
task_feature=task_feature,
batch_shape=batch_shape,
batch_shape=X.shape[:-2],
)
else:
warp_tf = None
Expand All @@ -741,7 +744,7 @@ def _get_model(
covar_module = _get_customized_covar_module(
covar_module_prior_dict=covar_module_prior_dict,
ard_num_dims=X.shape[-1],
batch_shape=batch_shape,
aug_batch_shape=_get_aug_batch_shape(X, Y),
task_feature=task_feature,
)

Expand Down Expand Up @@ -804,17 +807,18 @@ def _get_model(
def _get_customized_covar_module(
covar_module_prior_dict: Dict[str, Prior],
ard_num_dims: int,
batch_shape: torch.Size,
aug_batch_shape: torch.Size,
task_feature: Optional[int] = None,
) -> Kernel:
"""Construct a GP kernel based on customized prior dict.
Args:
covar_module_prior_dict: Dict. The keys are the names of the prior and values
are the priors. e.g. {"lengthscale_prior": GammaPrior(3.0, 6.0)}.
ard_num_dims: The dimension of the input, including task features
batch_shape: The batch_shape of the model
task_feature: The index of the task feature
ard_num_dims: The dimension of the inputs, including task features.
aug_batch_shape: The output dimension augmented batch shape of the model
(different from the batch shape for batched multi-output models).
task_feature: The index of the task feature.
"""
# TODO: add more checks of covar_module_prior_dict
if task_feature is not None:
Expand All @@ -823,19 +827,19 @@ def _get_customized_covar_module(
MaternKernel(
nu=2.5,
ard_num_dims=ard_num_dims,
batch_shape=batch_shape,
batch_shape=aug_batch_shape,
lengthscale_prior=covar_module_prior_dict.get(
"lengthscale_prior", GammaPrior(3.0, 6.0)
),
),
batch_shape=batch_shape,
batch_shape=aug_batch_shape,
outputscale_prior=covar_module_prior_dict.get(
"outputscale_prior", GammaPrior(2.0, 0.15)
),
)


def _get_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
def _get_aug_batch_shape(X: Tensor, Y: Tensor) -> torch.Size:
"""Obtain the output-augmented batch shape of GP model.
Args:
Expand Down

0 comments on commit 481d749

Please sign in to comment.