Skip to content

Commit

Permalink
Support per-metric model specification in MBM
Browse files Browse the repository at this point in the history
Summary:
Enables using different models for different metrics.

* Adds ModelConfig dataclass to specify a single botorch model configuration
* Adds a list of ModelConfigs to SurrogateSpec
* Adds a dictionary mapping metric names to list of ModelConfigs to enable per-metric model specification
* Lists of model configs are used to enable per-metric model selection across multiple ModelConfigs in a subsequent diff.

Reviewed By: saitcakmak

Differential Revision: D64793595
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 1, 2024
1 parent fcfda95 commit 7fb4c59
Show file tree
Hide file tree
Showing 10 changed files with 1,417 additions and 274 deletions.
15 changes: 12 additions & 3 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import (
Expand Down Expand Up @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None:
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP
saasbo.model.surrogate.model_configs[0].botorch_model_class,
SaasFullyBayesianSingleTaskGP,
)

@mock_botorch_optimize
Expand Down Expand Up @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, BoTorchModel)
self.assertEqual(mtgp.model.acquisition_class, Acquisition)
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
is_moo = isinstance(
exp.optimization_config, MultiObjectiveOptimizationConfig
)
if is_moo:
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
models = mtgp.model.surrogate.model.models
else:
models = [mtgp.model.surrogate.model]

for model in mtgp.model.surrogate.model.models:
for model in models:
self.assertIsInstance(
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
Expand Down
11 changes: 10 additions & 1 deletion ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
Expand Down Expand Up @@ -79,6 +80,8 @@ class SurrogateSpec:

allow_batched_models: bool = True

model_configs: list[ModelConfig] = field(default_factory=list)
metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
outcomes: list[str] = field(default_factory=list)


Expand Down Expand Up @@ -241,13 +244,19 @@ def fit(
input_transform_options=spec.input_transform_options,
outcome_transform_classes=spec.outcome_transform_classes,
outcome_transform_options=spec.outcome_transform_options,
model_configs=spec.model_configs,
metric_to_model_configs=spec.metric_to_model_configs,
allow_batched_models=spec.allow_batched_models,
)
else:
self._surrogate = Surrogate()

# Fit the surrogate.
self.surrogate.model_options.update(additional_model_inputs)
for config in self.surrogate.model_configs:
config.model_options.update(additional_model_inputs)
for config_list in self.surrogate.metric_to_model_configs.values():
for config in config_list:
config.model_options.update(additional_model_inputs)
self.surrogate.fit(
datasets=datasets,
search_space_digest=search_space_digest,
Expand Down
222 changes: 160 additions & 62 deletions ax/models/torch/botorch_modular/surrogate.py

Large diffs are not rendered by default.

118 changes: 114 additions & 4 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from dataclasses import dataclass, field
from logging import Logger
from typing import Any

Expand All @@ -34,29 +35,138 @@
from botorch.models.model import Model, ModelList
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor

MIN_OBSERVED_NOISE_LEVEL = 1e-7
logger: Logger = get_logger(__name__)


@dataclass
class ModelConfig:
"""Configuration for the BoTorch Model used in Surrogate.
Args:
botorch_model_class: ``Model`` class to be used as the underlying
BoTorch model. If None is provided a model class will be selected (either
one for all outcomes or a ModelList with separate models for each outcome)
will be selected automatically based off the datasets at `construct` time.
This argument is deprecated in favor of model_configs.
model_options: Dictionary of options / kwargs for the BoTorch
``Model`` constructed during ``Surrogate.fit``.
Note that the corresponding attribute will later be updated to include any
additional kwargs passed into ``BoTorchModel.fit``.
This argument is deprecated in favor of model_configs.
mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
This argument is deprecated in favor of model_configs.
mll_options: Dictionary of options / kwargs for the MLL. This argument is
deprecated in favor of model_configs.
outcome_transform_classes: List of BoTorch outcome transforms classes. Passed
down to the BoTorch ``Model``. Multiple outcome transforms can be chained
together using ``ChainedOutcomeTransform``. This argument is deprecated in
favor of model_configs.
outcome_transform_options: Outcome transform classes kwargs. The keys are
class string names and the values are dictionaries of outcome transform
kwargs. For example,
`
outcome_transform_classes = [Standardize]
outcome_transform_options = {
"Standardize": {"m": 1},
`
For more options see `botorch/models/transforms/outcome.py`. This argument
is deprecated in favor of model_configs.
input_transform_classes: List of BoTorch input transforms classes.
Passed down to the BoTorch ``Model``. Multiple input transforms
will be chained together using ``ChainedInputTransform``.
This argument is deprecated in favor of model_configs.
input_transform_options: Input transform classes kwargs. The keys are
class string names and the values are dictionaries of input transform
kwargs. For example,
`
input_transform_classes = [Normalize, Round]
input_transform_options = {
"Normalize": {"d": 3},
"Round": {"integer_indices": [0], "categorical_features": {1: 2}},
}
`
For more input options see `botorch/models/transforms/input.py`.
This argument is deprecated in favor of model_configs.
covar_module_class: Covariance module class. This gets initialized after
parsing the ``covar_module_options`` in ``covar_module_argparse``,
and gets passed to the model constructor as ``covar_module``.
This argument is deprecated in favor of model_configs.
covar_module_options: Covariance module kwargs. This argument is deprecated
in favor of model_configs.
likelihood: ``Likelihood`` class. This gets initialized with
``likelihood_options`` and gets passed to the model constructor.
This argument is deprecated in favor of model_configs.
likelihood_options: Likelihood options. This argument is deprecated in favor
of model_configs.
"""

botorch_model_class: type[Model] | None = None
model_options: dict[str, Any] = field(default_factory=dict)
mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
mll_options: dict[str, Any] = field(default_factory=dict)
input_transform_classes: list[type[InputTransform]] | None = None
input_transform_options: dict[str, dict[str, Any]] | None = field(
default_factory=dict
)
outcome_transform_classes: list[type[OutcomeTransform]] | None = None
outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict)
covar_module_class: type[Kernel] | None = None
covar_module_options: dict[str, Any] = field(default_factory=dict)
likelihood_class: type[Likelihood] | None = None
likelihood_options: dict[str, Any] = field(default_factory=dict)


def use_model_list(
datasets: Sequence[SupervisedDataset],
botorch_model_class: type[Model],
model_configs: list[ModelConfig] | None = None,
metric_to_model_configs: dict[str, list[ModelConfig]] | None = None,
allow_batched_models: bool = True,
) -> bool:
if issubclass(botorch_model_class, MultiTaskGP):
# We currently always wrap multi-task models into `ModelListGP`.
model_configs = model_configs or []
metric_to_model_configs = metric_to_model_configs or {}
if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1:
# There is only one outcome, so we can use a single model.
return False
elif (
len(model_configs) > 1
or len(metric_to_model_configs) > 0
or any(len(model_config) for model_config in metric_to_model_configs.values())
):
# There are multiple outcomes and outcomes might be modeled with different
# models
return True
elif issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP):
# Otherwise, the same model class is used for all outcomes.
# Determine what the model class is.
if len(model_configs) > 0:
botorch_model_class = (
model_configs[0].botorch_model_class or botorch_model_class
)
if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP):
# SAAS models do not support multiple outcomes.
# Use model list if there are multiple outcomes.
return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
elif issubclass(botorch_model_class, MultiTaskGP):
# We wrap multi-task models into `ModelListGP` when there are
# multiple outcomes.
return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
elif len(datasets) == 1:
# Just one outcome, can use single model.
# This method is called before multiple datasets are merged into
# one if using a batched model. If there is one dataset here,
# there should be a reason that a single model should be used:
# e.g. a contextual model, where we want to jointly model the metric
# each context (and context-level metrics are different outcomes).
return False
elif issubclass(botorch_model_class, BatchedMultiOutputGPyTorchModel) and all(
torch.equal(datasets[0].X, ds.X) for ds in datasets[1:]
Expand Down
21 changes: 20 additions & 1 deletion ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ax.models.torch.botorch_modular.utils import (
choose_model_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
)
from ax.models.torch.utils import _filter_X_observed
from ax.models.torch_base import TorchOptConfig
Expand Down Expand Up @@ -327,7 +328,21 @@ def test_fit(self, mock_fit: Mock) -> None:
mock_fit.assert_called_with(
dataset=self.block_design_training_data[0],
search_space_digest=self.mf_search_space_digest,
botorch_model_class=SingleTaskMultiFidelityGP,
model_config=ModelConfig(
botorch_model_class=None,
model_options={},
mll_class=ExactMarginalLogLikelihood,
mll_options={},
input_transform_classes=None,
input_transform_options={},
outcome_transform_classes=None,
outcome_transform_options={},
covar_module_class=None,
covar_module_options={},
likelihood_class=None,
likelihood_options={},
),
default_botorch_model_class=SingleTaskMultiFidelityGP,
state_dict=None,
refit=True,
)
Expand Down Expand Up @@ -727,6 +742,8 @@ def test_surrogate_model_options_propagation(
input_transform_options=None,
outcome_transform_classes=None,
outcome_transform_options=None,
model_configs=[],
metric_to_model_configs={},
allow_batched_models=True,
)

Expand Down Expand Up @@ -755,6 +772,8 @@ def test_surrogate_options_propagation(
input_transform_options=None,
outcome_transform_classes=None,
outcome_transform_options=None,
model_configs=[],
metric_to_model_configs={},
allow_batched_models=False,
)

Expand Down
Loading

0 comments on commit 7fb4c59

Please sign in to comment.