From 4a4a5bd29404e8605a7711ccb78233bd0e6f3da0 Mon Sep 17 00:00:00 2001 From: David Eriksson Date: Thu, 16 Nov 2023 08:29:03 -0800 Subject: [PATCH] Clean up is_fully_bayesian (#2108) Summary: X-link: https://github.com/facebook/Ax/pull/1992 Pull Request resolved: https://github.com/pytorch/botorch/pull/2108 This attempts to clean up the usage of `is_fully_bayesian` and also separately treat fully Bayesian models from ensemble models. The main changes in diff are to: - Add an `_is_fully_bayesian` attribute to `Model`. This is `True` for fully Bayesian models that rely on Pyro/NUTS to be fitted (they need some special handling for fitting and `state_dict` loading/saving. - Add an `_is_ensemble` attribute to `Model`. This indicates whether the model is a collection of multiple models that are stored in an additional batch dimension. This is hopefully a better classification, but I'm open to a different name here. - Rename `FullyBayesianPosterior` to `GaussianMixturePosterior` since that is more descriptive and plays better with the other changes. Reviewed By: esantorella Differential Revision: D50884342 fbshipit-source-id: 0ba603416c1823026c4fdf2e445cefdf8036cda8 --- botorch/acquisition/multi_objective/logei.py | 8 +- .../multi_objective/monte_carlo.py | 4 +- botorch/acquisition/multi_objective/utils.py | 4 +- botorch/acquisition/utils.py | 4 +- botorch/models/fully_bayesian.py | 12 ++- botorch/models/fully_bayesian_multitask.py | 12 ++- botorch/models/gpytorch.py | 12 +-- botorch/models/model.py | 8 +- botorch/posteriors/__init__.py | 6 +- botorch/posteriors/fully_bayesian.py | 26 +++-- botorch/posteriors/posterior_list.py | 27 ++--- botorch/utils/gp_sampling.py | 6 +- botorch/utils/transforms.py | 40 +++---- test/models/test_fully_bayesian.py | 6 +- test/models/test_fully_bayesian_multitask.py | 8 +- test/utils/test_gp_sampling.py | 4 +- test/utils/test_transforms.py | 102 ++++++++++++++---- 17 files changed, 191 insertions(+), 98 deletions(-) diff --git a/botorch/acquisition/multi_objective/logei.py b/botorch/acquisition/multi_objective/logei.py index 35bd308ba8..b43ee75b4b 100644 --- a/botorch/acquisition/multi_objective/logei.py +++ b/botorch/acquisition/multi_objective/logei.py @@ -38,7 +38,7 @@ ) from botorch.utils.transforms import ( concatenate_pending_points, - is_fully_bayesian, + is_ensemble, match_batch_shape, t_batch_mode_transform, ) @@ -454,9 +454,9 @@ def forward(self, X: Tensor) -> Tensor: # 1) X and X, and # 2) X and X_baseline. posterior = self.model.posterior(X_full) - # Account for possible one-to-many transform and the MCMC batch dimension in - # `SaasFullyBayesianSingleTaskGP` - event_shape_lag = 1 if is_fully_bayesian(self.model) else 2 + # Account for possible one-to-many transform and the model batch dimensions in + # ensemble models. + event_shape_lag = 1 if is_ensemble(self.model) else 2 n_w = ( posterior._extended_shape()[X_full.dim() - event_shape_lag] // X_full.shape[-2] diff --git a/botorch/acquisition/multi_objective/monte_carlo.py b/botorch/acquisition/multi_objective/monte_carlo.py index 68cc2eb21c..29ac6655c3 100644 --- a/botorch/acquisition/multi_objective/monte_carlo.py +++ b/botorch/acquisition/multi_objective/monte_carlo.py @@ -49,7 +49,7 @@ from botorch.utils.objective import compute_smoothed_feasibility_indicator from botorch.utils.transforms import ( concatenate_pending_points, - is_fully_bayesian, + is_ensemble, match_batch_shape, t_batch_mode_transform, ) @@ -453,7 +453,7 @@ def forward(self, X: Tensor) -> Tensor: posterior = self.model.posterior(X_full) # Account for possible one-to-many transform and the MCMC batch dimension in # `SaasFullyBayesianSingleTaskGP` - event_shape_lag = 1 if is_fully_bayesian(self.model) else 2 + event_shape_lag = 1 if is_ensemble(self.model) else 2 n_w = ( posterior._extended_shape()[X_full.dim() - event_shape_lag] // X_full.shape[-2] diff --git a/botorch/acquisition/multi_objective/utils.py b/botorch/acquisition/multi_objective/utils.py index 6b95f0c5e1..ac4cf85613 100644 --- a/botorch/acquisition/multi_objective/utils.py +++ b/botorch/acquisition/multi_objective/utils.py @@ -40,7 +40,7 @@ from botorch.utils.multi_objective.pareto import is_non_dominated from botorch.utils.objective import compute_feasibility_indicator from botorch.utils.sampling import draw_sobol_samples -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from torch import Tensor @@ -110,7 +110,7 @@ def prune_inferior_points_multi_objective( with `N_nz` the number of points in `X` that have non-zero (empirical, under `num_samples` samples) probability of being pareto optimal. """ - if marginalize_dim is None and is_fully_bayesian(model): + if marginalize_dim is None and is_ensemble(model): # TODO: Properly deal with marginalizing fully Bayesian models marginalize_dim = MCMC_DIM diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index ec60899fff..dc7760cc3f 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -31,7 +31,7 @@ from botorch.sampling.pathwise import draw_matheron_paths from botorch.utils.objective import compute_feasibility_indicator from botorch.utils.sampling import optimize_posterior_samples -from botorch.utils.transforms import is_fully_bayesian, normalize_indices +from botorch.utils.transforms import is_ensemble, normalize_indices from torch import Tensor @@ -263,7 +263,7 @@ def prune_inferior_points( with `N_nz` the number of points in `X` that have non-zero (empirical, under `num_samples` samples) probability of being the best point. """ - if marginalize_dim is None and is_fully_bayesian(model): + if marginalize_dim is None and is_ensemble(model): # TODO: Properly deal with marginalizing fully Bayesian models marginalize_dim = MCMC_DIM diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py index 76f7804f9f..7878c34ade 100644 --- a/botorch/models/fully_bayesian.py +++ b/botorch/models/fully_bayesian.py @@ -43,7 +43,7 @@ from botorch.models.transforms.outcome import OutcomeTransform from botorch.models.utils import validate_input_scaling from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM from gpytorch.constraints import GreaterThan from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import MaternKernel, ScaleKernel @@ -327,6 +327,9 @@ class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel): >>> posterior = saas_gp.posterior(test_X) """ + _is_fully_bayesian = True + _is_ensemble = True + def __init__( self, train_X: Tensor, @@ -508,7 +511,7 @@ def posterior( observation_noise: bool = False, posterior_transform: Optional[PosteriorTransform] = None, **kwargs: Any, - ) -> FullyBayesianPosterior: + ) -> GaussianMixturePosterior: r"""Computes the posterior over model outputs at the provided points. Args: @@ -526,7 +529,8 @@ def posterior( posterior_transform: An optional PosteriorTransform. Returns: - A `FullyBayesianPosterior` object. Includes observation noise if specified. + A `GaussianMixturePosterior` object. Includes observation noise + if specified. """ self._check_if_fitted() posterior = super().posterior( @@ -536,5 +540,5 @@ def posterior( posterior_transform=posterior_transform, **kwargs, ) - posterior = FullyBayesianPosterior(distribution=posterior.distribution) + posterior = GaussianMixturePosterior(distribution=posterior.distribution) return posterior diff --git a/botorch/models/fully_bayesian_multitask.py b/botorch/models/fully_bayesian_multitask.py index 8c027f75b4..b621f65d27 100644 --- a/botorch/models/fully_bayesian_multitask.py +++ b/botorch/models/fully_bayesian_multitask.py @@ -23,7 +23,7 @@ from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import MaternKernel @@ -189,6 +189,9 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP): >>> posterior = mtsaas_gp.posterior(test_X) """ + _is_fully_bayesian = True + _is_ensemble = True + def __init__( self, train_X: Tensor, @@ -335,11 +338,12 @@ def posterior( observation_noise: bool = False, posterior_transform: Optional[PosteriorTransform] = None, **kwargs: Any, - ) -> FullyBayesianPosterior: + ) -> GaussianMixturePosterior: r"""Computes the posterior over model outputs at the provided points. Returns: - A `FullyBayesianPosterior` object. Includes observation noise if specified. + A `GaussianMixturePosterior` object. Includes observation noise + if specified. """ self._check_if_fitted() posterior = super().posterior( @@ -349,7 +353,7 @@ def posterior( posterior_transform=posterior_transform, **kwargs, ) - posterior = FullyBayesianPosterior(distribution=posterior.distribution) + posterior = GaussianMixturePosterior(distribution=posterior.distribution) return posterior def forward(self, X: Tensor) -> MultivariateNormal: diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 0ebfed75ba..315dd78911 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -34,9 +34,9 @@ mod_batch_shape, multioutput_to_batch_mode_transform, ) -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior from botorch.posteriors.gpytorch import GPyTorchPosterior -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from torch import Tensor @@ -619,7 +619,7 @@ def posterior( - If no `posterior_transform` is provided and the component models have no `outcome_transform`, or if the component models only use linear outcome transforms like `Standardize` (i.e. not `Log`), returns a - `GPyTorchPosterior` or `FullyBayesianPosterior` object, + `GPyTorchPosterior` or `GaussianMixturePosterior` object, representing `batch_shape` joint distributions over `q` points and the outputs selected by `output_indices` each. Includes measurement noise if `observation_noise` is specified. @@ -650,16 +650,16 @@ def posterior( mvns = [p.distribution for p in posterior.posteriors] # Combining MTMVNs into a single MTMVN is currently not supported. if not any(isinstance(m, MultitaskMultivariateNormal) for m in mvns): - # Return the result as a GPyTorchPosterior/FullyBayesianPosterior. + # Return the result as a GPyTorchPosterior/GaussianMixturePosterior. mvn = ( mvns[0] if len(mvns) == 1 else MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) ) - if any(is_fully_bayesian(m) for m in self.models): + if any(is_ensemble(m) for m in self.models): # Mixing fully Bayesian and other GP models is currently # not supported. - posterior = FullyBayesianPosterior(distribution=mvn) + posterior = GaussianMixturePosterior(distribution=mvn) else: posterior = GPyTorchPosterior(distribution=mvn) if posterior_transform is not None: diff --git a/botorch/models/model.py b/botorch/models/model.py index 434f43ce7a..8c2be562dd 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -70,17 +70,23 @@ class Model(Module, ABC): `Tensor` or `Module` type are automatically registered so they can be moved and/or cast with the `to` method, automatically differentiated, and used with CUDA. - Args: + Attributes: _has_transformed_inputs: A boolean denoting whether `train_inputs` are currently stored as transformed or not. _original_train_inputs: A Tensor storing the original train inputs for use in `_revert_to_original_inputs`. Note that this is necessary since transform / untransform cycle introduces numerical errors which lead to upstream errors during training. + _is_fully_bayesian: Returns `True` if this is a fully Bayesian model. + _is_ensemble: Returns `True` if this model consists of multiple models + that are stored in an additional batch dimension. This is true for the fully + Bayesian models. """ # noqa: E501 _has_transformed_inputs: bool = False _original_train_inputs: Optional[Tensor] = None + _is_fully_bayesian = False + _is_ensemble = False @abstractmethod def posterior( diff --git a/botorch/posteriors/__init__.py b/botorch/posteriors/__init__.py index 6c5d16d036..a2b2211f55 100644 --- a/botorch/posteriors/__init__.py +++ b/botorch/posteriors/__init__.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. from botorch.posteriors.deterministic import DeterministicPosterior -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import ( + FullyBayesianPosterior, + GaussianMixturePosterior, +) from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.posteriors.higher_order import HigherOrderGPPosterior from botorch.posteriors.multitask import MultitaskGPPosterior @@ -16,6 +19,7 @@ __all__ = [ "DeterministicPosterior", + "GaussianMixturePosterior", "FullyBayesianPosterior", "GPyTorchPosterior", "HigherOrderGPPosterior", diff --git a/botorch/posteriors/fully_bayesian.py b/botorch/posteriors/fully_bayesian.py index 9b36553f36..bf5c133650 100644 --- a/botorch/posteriors/fully_bayesian.py +++ b/botorch/posteriors/fully_bayesian.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import Callable, Optional, Tuple +from warnings import warn import torch from botorch.posteriors.gpytorch import GPyTorchPosterior @@ -54,7 +55,7 @@ def batched_bisect( return center -def _quantile(posterior: FullyBayesianPosterior, value: Tensor) -> Tensor: +def _quantile(posterior: GaussianMixturePosterior, value: Tensor) -> Tensor: r"""Compute the posterior quantiles for the mixture of models.""" if value.numel() > 1: return torch.stack( @@ -78,13 +79,13 @@ def _quantile(posterior: FullyBayesianPosterior, value: Tensor) -> Tensor: ) -class FullyBayesianPosterior(GPyTorchPosterior): - r"""A posterior for a fully Bayesian model. +class GaussianMixturePosterior(GPyTorchPosterior): + r"""A Gaussian mixture posterior. The MCMC batch dimension that corresponds to the models in the mixture is located at `MCMC_DIM` (defined at the top of this file). Note that while each MCMC sample - corresponds to a Gaussian posterior, the fully Bayesian posterior is rather a - mixture of Gaussian distributions. + corresponds to a Gaussian posterior, the posterior is rather a mixture of Gaussian + distributions. """ def __init__(self, distribution: MultivariateNormal) -> None: @@ -137,7 +138,14 @@ def batch_range(self) -> Tuple[int, int]: provide consistency in the acquisition values, i.e., to ensure that a candidate produces same value regardless of its position on the t-batch. """ - if self._is_mt: - return (0, -2) - else: - return (0, -1) + return (0, -2) if self._is_mt else (0, -1) + + +class FullyBayesianPosterior(GaussianMixturePosterior): + """For backwards compatibility.""" + + warn( + "`FullyBayesianPosterior` is marked for deprecation, consider using " + "`GaussianMixturePosterior` instead.", + DeprecationWarning, + ) diff --git a/botorch/posteriors/posterior_list.py b/botorch/posteriors/posterior_list.py index 0ad8eaa0b8..e897afdfa1 100644 --- a/botorch/posteriors/posterior_list.py +++ b/botorch/posteriors/posterior_list.py @@ -15,7 +15,8 @@ from typing import Any, List, Optional import torch -from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM +from botorch.posteriors import FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM from botorch.posteriors.posterior import Posterior from torch import Tensor @@ -23,8 +24,8 @@ class PosteriorList(Posterior): r"""A Posterior represented by a list of independent Posteriors. - When at least one of the posteriors is a `FullyBayesianPosterior`, the other - posteriors are expanded to match the size of the `FullyBayesianPosterior`. + When at least one of the posteriors is a `GaussianMixturePosterior`, the other + posteriors are expanded to match the size of the `GaussianMixturePosterior`. """ def __init__(self, *posteriors: Posterior) -> None: @@ -44,16 +45,16 @@ def __init__(self, *posteriors: Posterior) -> None: self.posteriors = list(posteriors) @cached_property - def _is_fully_bayesian(self) -> bool: - r"""Check if any of the posteriors is a `FullyBayesianPosterior`.""" - return any(isinstance(p, FullyBayesianPosterior) for p in self.posteriors) + def _is_gaussian_mixture(self) -> bool: + r"""Check if any of the posteriors is a `GaussianMixturePosterior`.""" + return any(isinstance(p, GaussianMixturePosterior) for p in self.posteriors) def _get_mcmc_batch_dimension(self) -> int: """Return the number of MCMC samples in the corresponding batch dimension.""" mcmc_samples = [ p.mean.shape[MCMC_DIM] for p in self.posteriors - if isinstance(p, FullyBayesianPosterior) + if isinstance(p, (GaussianMixturePosterior, FullyBayesianPosterior)) ] if len(set(mcmc_samples)) > 1: raise NotImplementedError( @@ -70,12 +71,12 @@ def _reshape_tensor(X: Tensor, mcmc_samples: int) -> Tensor: def _reshape_and_cat(self, tensors: List[Tensor]): r"""Reshape, if needed, and concatenate (across dim=-1) a list of tensors.""" - if self._is_fully_bayesian: + if self._is_gaussian_mixture: mcmc_samples = self._get_mcmc_batch_dimension() return torch.cat( [ x - if isinstance(p, FullyBayesianPosterior) + if isinstance(p, GaussianMixturePosterior) else self._reshape_tensor(x, mcmc_samples=mcmc_samples) for x, p in zip(tensors, self.posteriors) ], @@ -112,16 +113,18 @@ def _extended_shape( r"""Returns the shape of the samples produced by the posterior with the given `sample_shape`. - If there's at least one `FullyBayesianPosterior`, the MCMC dimension + If there's at least one `GaussianMixturePosterior`, the MCMC dimension is included the `_extended_shape`. """ - if self._is_fully_bayesian: + if self._is_gaussian_mixture: mcmc_shape = torch.Size([self._get_mcmc_batch_dimension()]) extend_dim = MCMC_DIM + 1 # The dimension to inject MCMC shape. extended_shapes = [] for p in self.posteriors: es = p._extended_shape(sample_shape=sample_shape) - if self._is_fully_bayesian and not isinstance(p, FullyBayesianPosterior): + if self._is_gaussian_mixture and not isinstance( + p, GaussianMixturePosterior + ): # Extend the shapes of non-fully Bayesian ones to match. extended_shapes.append(es[:extend_dim] + mcmc_shape + es[extend_dim:]) else: diff --git a/botorch/utils/gp_sampling.py b/botorch/utils/gp_sampling.py index 7d57212e45..b5c0284f9b 100644 --- a/botorch/utils/gp_sampling.py +++ b/botorch/utils/gp_sampling.py @@ -17,7 +17,7 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.utils.sampling import manual_seed -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel from linear_operator.utils.cholesky import psd_safe_cholesky from torch import Tensor @@ -503,7 +503,7 @@ def get_gp_samples( models[m].outcome_transform = _octf if _intf is not None: base_gp_samples.models[m].input_transform = _intf - base_gp_samples.is_fully_bayesian = is_fully_bayesian(model=model) + base_gp_samples._is_ensemble = is_ensemble(model=model) return base_gp_samples elif n_samples > 1: base_gp_samples = get_deterministic_model_multi_samples( @@ -522,5 +522,5 @@ def get_gp_samples( if octf is not None: base_gp_samples.outcome_transform = octf model.outcome_transform = octf - base_gp_samples.is_fully_bayesian = is_fully_bayesian(model=model) + base_gp_samples._is_ensemble = is_ensemble(model=model) return base_gp_samples diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 729bd591b7..e7b96a08ce 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -165,33 +165,35 @@ def _verify_output_shape(acqf: Any, X: Tensor, output: Tensor) -> bool: def is_fully_bayesian(model: Model) -> bool: - r"""Check if at least one model is a SaasFullyBayesianSingleTaskGP + r"""Check if at least one model is a fully Bayesian model. Args: model: A BoTorch model (may be a `ModelList` or `ModelListGP`) - d: The dimension of the tensor to index. Returns: - True if at least one model is a `SaasFullyBayesianSingleTaskGP` + True if at least one model is a fully Bayesian model. """ from botorch.models import ModelList - from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP - from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP - full_bayesian_model_cls = ( - SaasFullyBayesianSingleTaskGP, - SaasFullyBayesianMultiTaskGP, - ) + if isinstance(model, ModelList): + return any(is_fully_bayesian(m) for m in model.models) + return getattr(model, "_is_fully_bayesian", False) - if isinstance(model, full_bayesian_model_cls) or getattr( - model, "is_fully_bayesian", False - ): - return True - elif isinstance(model, ModelList): - for m in model.models: - if is_fully_bayesian(m): - return True - return False + +def is_ensemble(model: Model) -> bool: + r"""Check if at least one model is an ensemble model. + + Args: + model: A BoTorch model (may be a `ModelList` or `ModelListGP`) + + Returns: + True if at least one model is an ensemble model. + """ + from botorch.models import ModelList + + if isinstance(model, ModelList): + return any(is_ensemble(m) for m in model.models) + return getattr(model, "_is_ensemble", False) def t_batch_mode_transform( @@ -255,7 +257,7 @@ def decorated( # add t-batch dim X = X if X.dim() > 2 else X.unsqueeze(0) output = method(acqf, X, *args, **kwargs) - if hasattr(acqf, "model") and is_fully_bayesian(acqf.model): + if hasattr(acqf, "model") and is_ensemble(acqf.model): # IDEA: this could be wrapped into SampleReducingMCAcquisitionFunction output = ( output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1) diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py index cfa242d5d7..43e9745743 100644 --- a/test/models/test_fully_bayesian.py +++ b/test/models/test_fully_bayesian.py @@ -50,7 +50,7 @@ SaasPyroModel, ) from botorch.models.transforms import Normalize, Standardize -from botorch.posteriors.fully_bayesian import batched_bisect, FullyBayesianPosterior +from botorch.posteriors.fully_bayesian import batched_bisect, GaussianMixturePosterior from botorch.sampling.get_sampler import get_sampler from botorch.utils.datasets import SupervisedDataset from botorch.utils.multi_objective.box_decompositions.non_dominated import ( @@ -246,7 +246,7 @@ def test_fit_model(self): for batch_shape in [[5], [6, 5, 2]]: test_X = torch.rand(*batch_shape, d, **tkwargs) posterior = model.posterior(test_X) - self.assertIsInstance(posterior, FullyBayesianPosterior) + self.assertIsInstance(posterior, GaussianMixturePosterior) # Mean/variance expected_shape = ( *batch_shape[: MCMC_DIM + 2], @@ -689,7 +689,7 @@ def f(x): variance = torch.rand(1, 5, **tkwargs) covar = torch.diag_embed(variance) mvn = MultivariateNormal(mean, to_linear_operator(covar)) - posterior = FullyBayesianPosterior(distribution=mvn) + posterior = GaussianMixturePosterior(distribution=mvn) dist = torch.distributions.Normal( loc=mean.unsqueeze(-1), scale=variance.unsqueeze(-1).sqrt() ) diff --git a/test/models/test_fully_bayesian_multitask.py b/test/models/test_fully_bayesian_multitask.py index 0cc9c25f73..66812fdc10 100644 --- a/test/models/test_fully_bayesian_multitask.py +++ b/test/models/test_fully_bayesian_multitask.py @@ -36,7 +36,7 @@ ) from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize -from botorch.posteriors import FullyBayesianPosterior +from botorch.posteriors import GaussianMixturePosterior from botorch.sampling.get_sampler import get_sampler from botorch.sampling.normal import IIDNormalSampler from botorch.utils.multi_objective.box_decompositions.non_dominated import ( @@ -252,12 +252,12 @@ def test_fit_model( for batch_shape in [[5], [5, 2], [5, 2, 6]]: test_X = torch.rand(*batch_shape, d, **tkwargs) posterior = model.posterior(test_X) - self.assertIsInstance(posterior, FullyBayesianPosterior) - self.assertIsInstance(posterior, FullyBayesianPosterior) + self.assertIsInstance(posterior, GaussianMixturePosterior) + self.assertIsInstance(posterior, GaussianMixturePosterior) test_X = torch.rand(*batch_shape, d, **tkwargs) posterior = model.posterior(test_X) - self.assertIsInstance(posterior, FullyBayesianPosterior) + self.assertIsInstance(posterior, GaussianMixturePosterior) # Mean/variance expected_shape = ( *batch_shape[: MCMC_DIM + 2], diff --git a/test/utils/test_gp_sampling.py b/test/utils/test_gp_sampling.py index 9d17453aad..cba9fc7674 100644 --- a/test/utils/test_gp_sampling.py +++ b/test/utils/test_gp_sampling.py @@ -27,7 +27,7 @@ RandomFourierFeatures, ) from botorch.utils.testing import BotorchTestCase -from botorch.utils.transforms import is_fully_bayesian +from botorch.utils.transforms import is_ensemble from gpytorch.kernels import MaternKernel, PeriodicKernel, RBFKernel, ScaleKernel from torch.distributions import MultivariateNormal @@ -686,7 +686,7 @@ def test_with_saas_models(self): num_outputs=1, n_samples=1, ) - self.assertTrue(is_fully_bayesian(gp_samples)) + self.assertTrue(is_ensemble(gp_samples)) # Non-batch evaluation. samples = gp_samples(torch.rand(2, 4, **tkwargs)) self.assertEqual(samples.shape, torch.Size([4, 2, 1])) diff --git a/test/utils/test_transforms.py b/test/utils/test_transforms.py index 99df59669c..105388df1f 100644 --- a/test/utils/test_transforms.py +++ b/test/utils/test_transforms.py @@ -9,16 +9,22 @@ import torch from botorch.models import ( GenericDeterministicModel, + HigherOrderGP, ModelList, - ModelListGP, + PairwiseGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP, ) +from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP +from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP +from botorch.models.gp_regression_mixed import MixedSingleTaskGP from botorch.models.model import Model +from botorch.models.multitask import MultiTaskGP from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior from botorch.utils.transforms import ( _verify_output_shape, concatenate_pending_points, + is_ensemble, is_fully_bayesian, match_batch_shape, normalize, @@ -299,26 +305,82 @@ def test_normalize_indices(self): class TestIsFullyBayesian(BotorchTestCase): def test_is_fully_bayesian(self): X, Y = torch.rand(3, 2), torch.randn(3, 1) - saas = SaasFullyBayesianSingleTaskGP(train_X=X, train_Y=Y) vanilla_gp = SingleTaskGP(train_X=X, train_Y=Y) deterministic = GenericDeterministicModel(f=lambda x: x) - # Single model - self.assertTrue(is_fully_bayesian(model=saas)) - self.assertFalse(is_fully_bayesian(model=vanilla_gp)) - self.assertFalse(is_fully_bayesian(model=deterministic)) - # ModelListGP - self.assertTrue(is_fully_bayesian(model=ModelListGP(saas, saas))) - self.assertTrue(is_fully_bayesian(model=ModelListGP(saas, vanilla_gp))) - self.assertFalse(is_fully_bayesian(model=ModelListGP(vanilla_gp, vanilla_gp))) - # ModelList - self.assertTrue(is_fully_bayesian(model=ModelList(saas, saas))) - self.assertTrue(is_fully_bayesian(model=ModelList(saas, deterministic))) - self.assertFalse(is_fully_bayesian(model=ModelList(vanilla_gp, deterministic))) - # Nested ModelList - self.assertTrue(is_fully_bayesian(model=ModelList(ModelList(saas), saas))) - self.assertTrue( - is_fully_bayesian(model=ModelList(ModelList(saas), deterministic)) + + fully_bayesian_models = ( + SaasFullyBayesianSingleTaskGP(train_X=X, train_Y=Y), + SaasFullyBayesianMultiTaskGP(train_X=X, train_Y=Y, task_feature=-1), + ) + for m in fully_bayesian_models: + self.assertTrue(is_fully_bayesian(model=m)) + # ModelList + self.assertTrue(is_fully_bayesian(model=ModelList(m, m))) + self.assertTrue(is_fully_bayesian(model=ModelList(m, vanilla_gp))) + self.assertTrue(is_fully_bayesian(model=ModelList(m, deterministic))) + # Nested ModelList + self.assertTrue(is_fully_bayesian(model=ModelList(ModelList(m), m))) + self.assertTrue( + is_fully_bayesian(model=ModelList(ModelList(m), deterministic)) + ) + + non_fully_bayesian_models = ( + GenericDeterministicModel(f=lambda x: x), + SingleTaskGP(train_X=X, train_Y=Y), + MultiTaskGP(train_X=X, train_Y=Y, task_feature=-1), + HigherOrderGP(train_X=X, train_Y=Y), + SingleTaskMultiFidelityGP(train_X=X, train_Y=Y, data_fidelity=3), + MixedSingleTaskGP(train_X=X, train_Y=Y, cat_dims=[1]), + PairwiseGP(datapoints=X, comparisons=None), + ) + for m in non_fully_bayesian_models: + self.assertFalse(is_fully_bayesian(model=m)) + # ModelList + self.assertFalse(is_fully_bayesian(model=ModelList(m, m))) + self.assertFalse(is_fully_bayesian(model=ModelList(m, vanilla_gp))) + self.assertFalse(is_fully_bayesian(model=ModelList(m, deterministic))) + # Nested ModelList + self.assertFalse(is_fully_bayesian(model=ModelList(ModelList(m), m))) + self.assertFalse( + is_fully_bayesian(model=ModelList(ModelList(m), deterministic)) + ) + + +class TestIsEnsemble(BotorchTestCase): + def test_is_ensemble(self): + X, Y = torch.rand(3, 2), torch.randn(3, 1) + vanilla_gp = SingleTaskGP(train_X=X, train_Y=Y) + deterministic = GenericDeterministicModel(f=lambda x: x) + + ensemble_models = ( + SaasFullyBayesianSingleTaskGP(train_X=X, train_Y=Y), + SaasFullyBayesianMultiTaskGP(train_X=X, train_Y=Y, task_feature=-1), ) - self.assertFalse( - is_fully_bayesian(model=ModelList(ModelList(vanilla_gp), deterministic)) + for m in ensemble_models: + self.assertTrue(is_ensemble(model=m)) + # ModelList + self.assertTrue(is_ensemble(model=ModelList(m, m))) + self.assertTrue(is_ensemble(model=ModelList(m, vanilla_gp))) + self.assertTrue(is_ensemble(model=ModelList(m, deterministic))) + # Nested ModelList + self.assertTrue(is_ensemble(model=ModelList(ModelList(m), m))) + self.assertTrue(is_ensemble(model=ModelList(ModelList(m), deterministic))) + + non_ensemble_models = ( + GenericDeterministicModel(f=lambda x: x), + SingleTaskGP(train_X=X, train_Y=Y), + MultiTaskGP(train_X=X, train_Y=Y, task_feature=-1), + HigherOrderGP(train_X=X, train_Y=Y), + SingleTaskMultiFidelityGP(train_X=X, train_Y=Y, data_fidelity=3), + MixedSingleTaskGP(train_X=X, train_Y=Y, cat_dims=[1]), + PairwiseGP(datapoints=X, comparisons=None), ) + for m in non_ensemble_models: + self.assertFalse(is_ensemble(model=m)) + # ModelList + self.assertFalse(is_ensemble(model=ModelList(m, m))) + self.assertFalse(is_ensemble(model=ModelList(m, vanilla_gp))) + self.assertFalse(is_ensemble(model=ModelList(m, deterministic))) + # Nested ModelList + self.assertFalse(is_ensemble(model=ModelList(ModelList(m), m))) + self.assertFalse(is_ensemble(model=ModelList(ModelList(m), deterministic)))