Skip to content

Commit

Permalink
MBM Surrogate: Clean up model fitting behavior and **kwargs usage (#1882
Browse files Browse the repository at this point in the history
)

Summary:

Kwargs: These were originally added to support passing around additional kwargs in subclasses that no-longer exist. They later silently took on the role of carrying the kwargs that gets passed down to model input constructors. The argument name has been updated with added docstring explaining what these do. These are now used to update `Surrogate.model_options` and passed to the input constructors from there.

Model fitting clean up: We had a bunch of duplicate logic for constructing the models, left over from the merger of `Surrogate` & `ListSurrogate`, which led to several bugs in the past and made the code much harder to maintain and review. This diff deduplicates and simplifies the model fitting logic.

Differential Revision: D49707895
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 28, 2023
1 parent bf74e8e commit c67d597
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 326 deletions.
8 changes: 3 additions & 5 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import FixedNoiseMultiTaskGP
from botorch.models.multitask import MultiTaskGP
from botorch.utils.types import DEFAULT
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
if use_saas
else [
Surrogate(
botorch_model_class=FixedNoiseMultiTaskGP,
botorch_model_class=MultiTaskGP,
mll_class=ExactMarginalLogLikelihood,
covar_module_class=ScaleMaternKernel,
covar_module_options={
Expand Down Expand Up @@ -514,9 +514,7 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None:
for i in range(len(models)):
self.assertIsInstance(
models[i],
SaasFullyBayesianMultiTaskGP
if use_saas
else FixedNoiseMultiTaskGP,
SaasFullyBayesianMultiTaskGP if use_saas else MultiTaskGP,
)
if use_saas is False:
self.assertIsInstance(models[i].covar_module, ScaleKernel)
Expand Down
16 changes: 9 additions & 7 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def fit(
# state dict by surrogate label
state_dicts: Optional[Mapping[str, Dict[str, Tensor]]] = None,
refit: bool = True,
**kwargs: Any,
**additional_model_inputs: Any,
) -> None:
"""Fit model to m outcomes.
Expand All @@ -264,6 +264,8 @@ def fit(
surrogate_specs. If using a single, pre-instantiated model use
`Keys.ONLY_SURROGATE.
refit: Whether to re-optimize model parameters.
additional_model_inputs: Additional kwargs to pass to the
model input constructor in ``Surrogate.fit``.
"""

if len(datasets) != len(metric_names):
Expand All @@ -288,7 +290,7 @@ def fit(
if state_dicts
else None,
refit=refit,
**kwargs,
additional_model_inputs=additional_model_inputs,
)
return

Expand Down Expand Up @@ -362,7 +364,7 @@ def fit(
candidate_metadata=candidate_metadata,
state_dict=(state_dicts or {}).get(label),
refit=refit,
**kwargs,
additional_model_inputs=additional_model_inputs,
)

@copy_doc(TorchModel.update)
Expand All @@ -372,7 +374,7 @@ def update(
metric_names: List[str],
search_space_digest: SearchSpaceDigest,
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
**kwargs: Any,
**additional_model_inputs: Any,
) -> None:
if len(self.surrogates) == 0:
raise UnsupportedError("Cannot update model that has not been fitted.")
Expand Down Expand Up @@ -417,7 +419,7 @@ def update(
candidate_metadata=candidate_metadata,
state_dict=state_dict,
refit=self.refit_on_update,
**kwargs,
additional_model_inputs=additional_model_inputs,
)

@single_surrogate_only
Expand Down Expand Up @@ -536,7 +538,7 @@ def cross_validate(
metric_names: List[str],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
**kwargs: Any,
**additional_model_inputs: Any,
) -> Tuple[Tensor, Tensor]:
# Will fail if metric_names exist across multiple models
surrogate_labels = (
Expand Down Expand Up @@ -589,7 +591,7 @@ def cross_validate(
search_space_digest=search_space_digest,
state_dicts=state_dicts,
refit=self.refit_on_cv,
**kwargs,
**additional_model_inputs,
)
X_test_prediction = self.predict_from_surrogate(
surrogate_label=surrogate_label, X=X_test
Expand Down
Loading

0 comments on commit c67d597

Please sign in to comment.