Skip to content

Commit

Permalink
move deprecation warnings to SurrogateSpec and pass surrogate_spec to…
Browse files Browse the repository at this point in the history
… Surrogate (#3025)

Summary:
Pull Request resolved: #3025

See title. This avoids defining the same arguments in two places, per feedback from Sait. Moving deprecation warnings to SurrogateSpec means warnings are raised raised while specifying the model rather than when it is instantiated.

Reviewed By: saitcakmak

Differential Revision: D65321401

fbshipit-source-id: d561a415d192a398531fbcf2bdbee69c83e6b7bc
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 6, 2024
1 parent f13ce47 commit 647bdc5
Show file tree
Hide file tree
Showing 14 changed files with 685 additions and 597 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.service.scheduler import SchedulerOptions
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import LogExpectedImprovement
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/tests/methods/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _test_mbm_acquisition(self, scheduler_options: SchedulerOptions) -> None:
self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = model_kwargs["surrogate_spec"]
self.assertEqual(
surrogate_spec.botorch_model_class.__name__,
surrogate_spec.model_configs[0].botorch_model_class.__name__,
"SingleTaskGP",
)

Expand Down
6 changes: 2 additions & 4 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.random.uniform import UniformGenerator
from ax.models.torch.botorch import BotorchModel
from ax.models.torch.botorch_modular.model import (
BoTorchModel as ModularBoTorchModel,
SurrogateSpec,
)
from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.models.torch.cbo_sac import SACBO
from ax.utils.common.kwargs import (
Expand Down
6 changes: 3 additions & 3 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from ax.models.random.sobol import SobolGenerator
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_SAASBO(self) -> None:
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.model_configs[0].botorch_model_class,
saasbo.model.surrogate.surrogate_spec.model_configs[0].botorch_model_class,
SaasFullyBayesianSingleTaskGP,
)

Expand Down
73 changes: 7 additions & 66 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any

import numpy.typing as npt
Expand All @@ -23,12 +22,11 @@
get_rounding_func,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.models.torch.botorch_modular.utils import (
check_outcome_dataset_match,
choose_botorch_acqf_class,
construct_acquisition_and_optimizer_options,
ModelConfig,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
Expand All @@ -38,53 +36,10 @@
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.models.model import Model
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


@dataclass(frozen=True)
class SurrogateSpec:
"""
Fields in the SurrogateSpec dataclass correspond to arguments in
``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
outcomes the Surrogate is responsible for modeling.
When ``BotorchModel.fit`` is called, these fields will be used to construct the
requisite Surrogate objects.
If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.
"""

botorch_model_class: type[Model] | None = None
botorch_model_kwargs: dict[str, Any] = field(default_factory=dict)

mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood
mll_kwargs: dict[str, Any] = field(default_factory=dict)

covar_module_class: type[Kernel] | None = None
covar_module_kwargs: dict[str, Any] | None = None

likelihood_class: type[Likelihood] | None = None
likelihood_kwargs: dict[str, Any] | None = None

input_transform_classes: list[type[InputTransform]] | None = None
input_transform_options: dict[str, dict[str, Any]] | None = None

outcome_transform_classes: list[type[OutcomeTransform]] | None = None
outcome_transform_options: dict[str, dict[str, Any]] | None = None

allow_batched_models: bool = True

model_configs: list[ModelConfig] = field(default_factory=list)
metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
outcomes: list[str] = field(default_factory=list)


class BoTorchModel(TorchModel, Base):
"""**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
Expand Down Expand Up @@ -230,31 +185,17 @@ def fit(

# If a surrogate has not been constructed, construct it.
if self._surrogate is None:
if (spec := self.surrogate_spec) is not None:
self._surrogate = Surrogate(
botorch_model_class=spec.botorch_model_class,
model_options=spec.botorch_model_kwargs,
mll_class=spec.mll_class,
mll_options=spec.mll_kwargs,
covar_module_class=spec.covar_module_class,
covar_module_options=spec.covar_module_kwargs,
likelihood_class=spec.likelihood_class,
likelihood_options=spec.likelihood_kwargs,
input_transform_classes=spec.input_transform_classes,
input_transform_options=spec.input_transform_options,
outcome_transform_classes=spec.outcome_transform_classes,
outcome_transform_options=spec.outcome_transform_options,
model_configs=spec.model_configs,
metric_to_model_configs=spec.metric_to_model_configs,
allow_batched_models=spec.allow_batched_models,
)
if self.surrogate_spec is not None:
self._surrogate = Surrogate(surrogate_spec=self.surrogate_spec)
else:
self._surrogate = Surrogate()

# Fit the surrogate.
for config in self.surrogate.model_configs:
for config in self.surrogate.surrogate_spec.model_configs:
config.model_options.update(additional_model_inputs)
for config_list in self.surrogate.metric_to_model_configs.values():
for (
config_list
) in self.surrogate.surrogate_spec.metric_to_model_configs.values():
for config in config_list:
config.model_options.update(additional_model_inputs)
self.surrogate.fit(
Expand Down
Loading

0 comments on commit 647bdc5

Please sign in to comment.