Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move deprecation warnings to SurrogateSpec and pass surrogate_spec to Surrogate #3025

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.service.scheduler import SchedulerOptions
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import LogExpectedImprovement
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/tests/methods/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None:
self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = model_kwargs["surrogate_spec"]
self.assertEqual(
surrogate_spec.botorch_model_class.__name__,
surrogate_spec.model_configs[0].botorch_model_class.__name__,
"SingleTaskGP",
)

Expand Down
6 changes: 2 additions & 4 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.random.uniform import UniformGenerator
from ax.models.torch.botorch import BotorchModel
from ax.models.torch.botorch_modular.model import (
BoTorchModel as ModularBoTorchModel,
SurrogateSpec,
)
from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.models.torch.cbo_sac import SACBO
from ax.utils.common.kwargs import (
Expand Down
19 changes: 14 additions & 5 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import (
Expand All @@ -27,8 +28,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -99,7 +100,8 @@ def test_SAASBO(self) -> None:
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP
saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class,
SaasFullyBayesianSingleTaskGP,
)

@mock_botorch_optimize
Expand Down Expand Up @@ -459,9 +461,16 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
self.assertIsInstance(mtgp, TorchModelBridge)
self.assertIsInstance(mtgp.model, BoTorchModel)
self.assertEqual(mtgp.model.acquisition_class, Acquisition)
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
is_moo = isinstance(
exp.optimization_config, MultiObjectiveOptimizationConfig
)
if is_moo:
self.assertIsInstance(mtgp.model.surrogate.model, ModelListGP)
models = mtgp.model.surrogate.model.models
else:
models = [mtgp.model.surrogate.model]

for model in mtgp.model.surrogate.model.models:
for model in models:
self.assertIsInstance(
model,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
Expand Down
70 changes: 10 additions & 60 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any

import numpy.typing as npt
Expand All @@ -23,7 +22,7 @@
get_rounding_func,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
check_outcome_dataset_match,
choose_botorch_acqf_class,
Expand All @@ -37,51 +36,10 @@
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.models.model import Model
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


@dataclass(frozen=True)
class SurrogateSpec:
"""
Fields in the SurrogateSpec dataclass correspond to arguments in
``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
outcomes the Surrogate is responsible for modeling.
When ``BotorchModel.fit`` is called, these fields will be used to construct the
requisite Surrogate objects.
If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.
"""

botorch_model_class: type[Model] | None = None
botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)

mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
mll_kwargs: dict[str, Any] = field(default_factory=dict)

covar_module_class: type[Kernel] | None = None
covar_module_kwargs: dict[str, Any] | None = None

likelihood_class: type[Likelihood] | None = None
likelihood_kwargs: dict[str, Any] | None = None

input_transform_classes: list[type[InputTransform]] | None = None
input_transform_options: dict[str, dict[str, Any]] | None = None

outcome_transform_classes: list[type[OutcomeTransform]] | None = None
outcome_transform_options: dict[str, dict[str, Any]] | None = None

allow_batched_models: bool = True

outcomes: list[str] = field(default_factory=list)


class BoTorchModel(TorchModel, Base):
"""**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
Expand Down Expand Up @@ -227,27 +185,19 @@ def fit(

# If a surrogate has not been constructed, construct it.
if self._surrogate is None:
if (spec := self.surrogate_spec) is not None:
self._surrogate = Surrogate(
botorch_model_class=spec.botorch_model_class,
model_options=spec.botorch_model_kwargs,
mll_class=spec.mll_class,
mll_options=spec.mll_kwargs,
covar_module_class=spec.covar_module_class,
covar_module_options=spec.covar_module_kwargs,
likelihood_class=spec.likelihood_class,
likelihood_options=spec.likelihood_kwargs,
input_transform_classes=spec.input_transform_classes,
input_transform_options=spec.input_transform_options,
outcome_transform_classes=spec.outcome_transform_classes,
outcome_transform_options=spec.outcome_transform_options,
allow_batched_models=spec.allow_batched_models,
)
if self.surrogate_spec is not None:
self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec)
else:
self._surrogate = Surrogate()

# Fit the surrogate.
self.surrogate.model_options.update(additional_model_inputs)
for config in self.surrogate.surrogate_spec.model_configs:
config.model_options.update(additional_model_inputs)
for (
config_list
) in self.surrogate.surrogate_spec.metric_to_model_configs.values():
for config in config_list:
config.model_options.update(additional_model_inputs)
self.surrogate.fit(
datasets=datasets,
search_space_digest=search_space_digest,
Expand Down
Loading