MBM Surrogate: Clean up model fitting behavior and **kwargs usage (facebook#1882)

Summary:

Kwargs: These were originally added to support passing around additional kwargs in subclasses that no longer exist. They later silently took on the role of carrying the kwargs that get passed down to the model input constructors. The argument name has been updated, with an added docstring explaining what these do. These kwargs are now used to update `Surrogate.model_options` and are passed to the input constructors from there.
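As a rough illustration of the described flow (a standalone sketch with hypothetical names, not the actual Ax or BoTorch classes): extra keyword arguments passed to a fit call update a persistent `model_options` dict on the surrogate and reach the model input constructor from there.

```python
# Standalone sketch of the pattern described above; ToySurrogate, ToyModel and
# make_model_inputs are illustrative stand-ins, not part of Ax or BoTorch.
from typing import Any, Dict, Optional


class ToyModel:
    def __init__(self, **inputs: Any) -> None:
        self.inputs = inputs


def make_model_inputs(dataset: Dict[str, Any], **options: Any) -> Dict[str, Any]:
    # Stand-in for a BoTorch model input constructor.
    return {"train_data": dataset, **options}


class ToySurrogate:
    def __init__(self, model_options: Optional[Dict[str, Any]] = None) -> None:
        self.model_options: Dict[str, Any] = model_options or {}

    def fit(self, dataset: Dict[str, Any], **additional_model_inputs: Any) -> ToyModel:
        # Instead of silently threading **kwargs through, the extra inputs
        # update model_options and reach the input constructor from there.
        self.model_options.update(additional_model_inputs)
        return ToyModel(**make_model_inputs(dataset, **self.model_options))


surrogate = ToySurrogate()
model = surrogate.fit({"X": [0.0], "Y": [1.0]}, some_option=5)
assert model.inputs["some_option"] == 5
```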

Model fitting cleanup: We had a lot of duplicated logic for constructing the models, left over from the merger of `Surrogate` & `ListSurrogate`, which led to several bugs in the past and made the code much harder to maintain and review. This diff deduplicates and simplifies the model fitting logic.
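A minimal sketch of what a single, deduplicated fitting path can look like (hypothetical names, reusing `ToySurrogate` from the sketch above; not the actual `BoTorchModel.fit` implementation): each surrogate is handed only the datasets for the metrics it owns, and every surrogate goes through the same call site.

```python
from typing import Any, Dict, List


def fit_all_surrogates(
    surrogates: Dict[str, ToySurrogate],
    datasets_by_metric: Dict[str, Dict[str, Any]],
    metrics_by_surrogate: Dict[str, List[str]],
    **additional_model_inputs: Any,
) -> None:
    # One loop and one call site: the single-surrogate case is just a dict
    # with one entry, so there is no separate code path to maintain.
    for label, surrogate in surrogates.items():
        subset = {m: datasets_by_metric[m] for m in metrics_by_surrogate[label]}
        surrogate.fit(subset, **additional_model_inputs)
```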

Differential Revision: D49707895
saitcakmak authored and facebook-github-bot committed Oct 6, 2023
1 parent dd10bee commit 7039fc7
Showing 8 changed files with 173 additions and 342 deletions.
ax/modelbridge/tests/test_registry.py (8 changes: 3 additions & 5 deletions)
@@ -47,7 +47,7 @@
 from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
 from botorch.models.gp_regression import FixedNoiseGP
 from botorch.models.model_list_gp_regression import ModelListGP
-from botorch.models.multitask import FixedNoiseMultiTaskGP
+from botorch.models.multitask import MultiTaskGP
 from botorch.utils.types import DEFAULT
 from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -475,7 +475,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
             if use_saas
             else [
                 Surrogate(
-                    botorch_model_class=FixedNoiseMultiTaskGP,
+                    botorch_model_class=MultiTaskGP,
                     mll_class=ExactMarginalLogLikelihood,
                     covar_module_class=ScaleMaternKernel,
                     covar_module_options={
@@ -514,9 +514,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
         for i in range(len(models)):
             self.assertIsInstance(
                 models[i],
-                SaasFullyBayesianMultiTaskGP
-                if use_saas
-                else FixedNoiseMultiTaskGP,
+                SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
             )
             if use_saas is False:
                 self.assertIsInstance(models[i].covar_module, ScaleKernel)
ax/models/torch/botorch_modular/model.py (32 changes: 9 additions & 23 deletions)
@@ -24,7 +24,6 @@
 from ax.models.torch.botorch_modular.utils import (
     choose_botorch_acqf_class,
     construct_acquisition_and_optimizer_options,
-    convert_to_block_design,
 )
 from ax.models.torch.utils import _to_inequality_constraints
 from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
@@ -33,7 +32,6 @@
 from ax.utils.common.docutils import copy_doc
 from ax.utils.common.typeutils import checked_cast
 from botorch.acquisition.acquisition import AcquisitionFunction
-from botorch.models import ModelList
 from botorch.models.deterministic import FixedSingleSampleModel
 from botorch.models.model import Model
 from botorch.models.transforms.input import InputTransform
@@ -247,7 +245,7 @@ def fit(
         # state dict by surrogate label
         state_dicts: Optional[Mapping[str, Dict[str, Tensor]]] = None,
         refit: bool = True,
-        **kwargs: Any,
+        **additional_model_inputs: Any,
     ) -> None:
         """Fit model to m outcomes.

@@ -264,6 +262,8 @@
                 surrogate_specs. If using a single, pre-instantiated model use
                 `Keys.ONLY_SURROGATE.
             refit: Whether to re-optimize model parameters.
+            additional_model_inputs: Additional kwargs to pass to the
+                model input constructor in ``Surrogate.fit``.
         """

         if len(datasets) != len(metric_names):
@@ -288,7 +288,7 @@
                 if state_dicts
                 else None,
                 refit=refit,
-                **kwargs,
+                additional_model_inputs=additional_model_inputs,
             )
             return

@@ -340,20 +340,6 @@ def fit(
                 datasets_by_metric_name[metric_name]
                 for metric_name in subset_metric_names
             ]
-            if (
-                len(subset_datasets) > 1
-                # if Surrogate's model is none a ModelList will be autoset
-                and surrogate._model is not None
-                and not isinstance(surrogate.model, ModelList)
-            ):
-                # Note: If the datasets do not confirm to a block design then this
-                # will filter the data and drop observations to make sure that it does.
-                # This can happen e.g. if only some metrics are observed at some points
-                subset_datasets, metric_names = convert_to_block_design(
-                    datasets=subset_datasets,
-                    metric_names=metric_names,
-                    force=True,
-                )

             surrogate.fit(
                 datasets=subset_datasets,
@@ -362,7 +348,7 @@
                 candidate_metadata=candidate_metadata,
                 state_dict=(state_dicts or {}).get(label),
                 refit=refit,
-                **kwargs,
+                additional_model_inputs=additional_model_inputs,
             )

     @copy_doc(TorchModel.update)
@@ -372,7 +358,7 @@ def update(
         metric_names: List[str],
         search_space_digest: SearchSpaceDigest,
         candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
-        **kwargs: Any,
+        **additional_model_inputs: Any,
     ) -> None:
         if len(self.surrogates) == 0:
             raise UnsupportedError("Cannot update model that has not been fitted.")
@@ -417,7 +403,7 @@
                 candidate_metadata=candidate_metadata,
                 state_dict=state_dict,
                 refit=self.refit_on_update,
-                **kwargs,
+                additional_model_inputs=additional_model_inputs,
             )

     @single_surrogate_only
@@ -536,7 +522,7 @@ def cross_validate(
         metric_names: List[str],
         X_test: Tensor,
         search_space_digest: SearchSpaceDigest,
-        **kwargs: Any,
+        **additional_model_inputs: Any,
     ) -> Tuple[Tensor, Tensor]:
         # Will fail if metric_names exist across multiple models
         surrogate_labels = (
@@ -589,7 +575,7 @@
                 search_space_digest=search_space_digest,
                 state_dicts=state_dicts,
                 refit=self.refit_on_cv,
-                **kwargs,
+                **additional_model_inputs,
             )
             X_test_prediction = self.predict_from_surrogate(
                 surrogate_label=surrogate_label, X=X_test