Skip to content

Commit

Permalink
Remove Surrogate.from_botorch
Browse files Browse the repository at this point in the history
Summary: This is buggy and unsupported.

Reviewed By: Balandat

Differential Revision: D50581677
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 24, 2023
1 parent 2762e91 commit addf6a3
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 119 deletions.
51 changes: 7 additions & 44 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.input_constructors.covar_modules import (
covar_module_argparse,
Expand Down Expand Up @@ -159,10 +159,6 @@ def __init__(
self._training_data: Optional[List[SupervisedDataset]] = None
self._outcomes: Optional[List[str]] = None
self._model: Optional[Model] = None
# Special setting for surrogates instantiated via `Surrogate.from_botorch`,
# to avoid re-constructing the underlying BoTorch model on `Surrogate.fit`
# when set to `False`.
self._constructed_manually: bool = False

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -212,22 +208,6 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.training_data[0].X.device

@classmethod
def from_botorch(
cls,
model: Model,
mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
) -> Surrogate:
"""Instantiate a `Surrogate` from a pre-instantiated Botorch `Model`."""
surrogate = cls(botorch_model_class=model.__class__, mll_class=mll_class)
surrogate._model = model
# Temporarily disallowing `update` for surrogates instantiated from
# pre-made BoTorch `Model` instances to avoid reconstructing models
# that were likely pre-constructed for a reason (e.g. if this setup
# doesn't fully allow to constuct them).
surrogate._constructed_manually = True
return surrogate

def clone_reset(self) -> Surrogate:
return self.__class__(**self._serialize_attributes_as_kwargs())

Expand All @@ -247,9 +227,6 @@ def construct(
search_space_digest: Information about the search space used for
inferring suitable botorch model class.
"""
if self._constructed_manually:
logger.warning("Reconstructing a manually constructed `Model`.")

# To determine whether to use ModelList under the hood, we need to check for
# the batched multi-output case, so we first see which model would be chosen
# given the Yvars and the properties of data.
Expand Down Expand Up @@ -535,19 +512,12 @@ def fit(
state_dict: Optional state dict to load.
refit: Whether to re-optimize model parameters.
"""
if self._constructed_manually:
logger.debug(
"For manually constructed surrogates (via `Surrogate.from_botorch`), "
"`fit` skips setting the training data on model and only reoptimizes "
"its parameters if `refit=True`."
)
else:
self.construct(
datasets=datasets,
metric_names=metric_names,
search_space_digest=search_space_digest,
)
self._outcomes = metric_names
self.construct(
datasets=datasets,
metric_names=metric_names,
search_space_digest=search_space_digest,
)
self._outcomes = metric_names

if state_dict:
self.model.load_state_dict(not_none(state_dict))
Expand Down Expand Up @@ -662,13 +632,6 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
"""Serialize attributes of this surrogate, to be passed back to it
as kwargs on reinstantiation.
"""
if self._constructed_manually:
raise UnsupportedError(
"Surrogates constructed manually (ie Surrogate.from_botorch) may not "
"be serialized. If serialization is necessary please initialize from "
"the constructor."
)

return {
"botorch_model_class": self.botorch_model_class,
"model_options": self.model_options,
Expand Down
19 changes: 1 addition & 18 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import torch
from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UserInputError
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model
Expand Down Expand Up @@ -304,15 +304,6 @@ def test_device_property(self) -> None:
)
self.assertEqual(self.device, surrogate.device)

def test_from_botorch(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
surrogate_kwargs = botorch_model_class.construct_inputs(
self.training_data[0]
)
surrogate = Surrogate.from_botorch(botorch_model_class(**surrogate_kwargs))
self.assertIsInstance(surrogate.model, botorch_model_class)
self.assertTrue(surrogate._constructed_manually)

@patch(f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.__init__", return_value=None)
@patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None)
def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None:
Expand All @@ -338,7 +329,6 @@ def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None:
call_kwargs = mock_GPs[i].call_args[1]
self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0]))
self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0]))
self.assertFalse(surrogate._constructed_manually)

# Check that `model_options` passed to the `Surrogate` constructor are
# properly propagated.
Expand Down Expand Up @@ -591,13 +581,6 @@ def test_serialize_attributes_as_kwargs(self) -> None:
}
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)

with self.assertRaisesRegex(
UnsupportedError, "Surrogates constructed manually"
):
surrogate, _ = self._get_surrogate(botorch_model_class=SingleTaskGP)
surrogate._constructed_manually = True
surrogate._serialize_attributes_as_kwargs()

def test_w_robust_digest(self) -> None:
surrogate = Surrogate(
botorch_model_class=SingleTaskGP,
Expand Down
23 changes: 0 additions & 23 deletions tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,29 +299,6 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "f3c267de-f9b2-4524-852b-156fc47d1745"
},
"source": [
"Alternatively, for BoTorch `Model`-s that require complex instantiation procedures, leverage the `from_BoTorch` instantiation method of `Surrogate`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"originalKey": "905ba2bc-37e4-4a12-8442-1c772ffd15d0"
},
"outputs": [],
"source": [
"surrogate_from_botorch_model = Surrogate.from_BoTorch(\n",
" model=..., # BoTorch `Model` instance, with training data already set\n",
" mll_class=ExactMarginalLogLikelihood, # Optional, MLL class with which to optimize model parameters\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
36 changes: 2 additions & 34 deletions tutorials/modular_botax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
"1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n",
" 1. Example with minimal options that uses the defaults\n",
" 2. Example showing all possible options\n",
" 3. Using a pre-constructed BoTorch Model (e.g. in research or development)\n",
" 4. Surrogate and Acquisition Q&A\n",
" 3. Surrogate and Acquisition Q&A\n",
"2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n",
" 1. Making a `Surrogate` from BoTorch `Model`\n",
" 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n",
Expand Down Expand Up @@ -269,45 +268,14 @@
")"
]
},
{
"cell_type": "markdown",
"id": "critical-receptor",
"metadata": {
"originalKey": "5b15f6d8-27a0-410e-95ff-4a304bf35498"
},
"source": [
"## 2C. `Surrogate` from pre-instantiated BoTorch `Model`\n",
"\n",
"Alternatively, for BoTorch `Model`-s that require complex instantiation procedures (or is in development stage), leverage the `from_botorch` instantiation method of Surrogate:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fourth-moore",
"metadata": {
"originalKey": "ce873686-cedd-4f3a-9476-b6e78f1c3650"
},
"outputs": [],
"source": [
"from_botorch_model = BoTorchModel(\n",
" surrogate=Surrogate.from_botorch(\n",
" # Pre-constructed BoTorch `Model` instance, with training data already set\n",
" model=...,\n",
" # Optional, MLL class with which to optimize model parameters\n",
" mll_class=ExactMarginalLogLikelihood,\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "fourth-material",
"metadata": {
"originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e"
},
"source": [
"## 2D. `Surrogate` and `Acquisition` Q&A\n",
"## 2C. `Surrogate` and `Acquisition` Q&A\n",
"\n",
"**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n",
"\n",
Expand Down

0 comments on commit addf6a3

Please sign in to comment.