From f83d9a627ce42376d28827c4777c5ceaa7af644e Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Thu, 19 Dec 2024 08:36:38 -0800 Subject: [PATCH] Make `Acquisition.optimize` work with discrete optimizer regardless of whether `num_restarts` is in `optimizer_options` (#3197) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3197 Context: In `Acquisition.optimize`, we remove `raw_samples` from `optimizer_options` when the discrete optimizer was used (D63035021), but we don't do the same for `num_restarts` even though it is also not supported by the discrete optimizer. This PR: * Removes `num_restarts` from `optimizer_options` when the discrete optimizer is ued * Takes out a TODO to ensure it is never passed in the first place-- not feasible since we won't know if the optimizer is discrete before hitting that point Reviewed By: saitcakmak Differential Revision: D67419600 fbshipit-source-id: df6f4beb06a3bc678337ac563a77d6f2146a25ba --- ax/models/torch/botorch_modular/acquisition.py | 6 +++--- ax/models/torch/tests/test_acquisition.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index cc21ceb7163..49b1f41cb5a 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -309,11 +309,11 @@ def optimize( optimizer = "optimize_acqf_discrete_local_search" else: optimizer = "optimize_acqf_discrete" - # `raw_samples` is not supported by `optimize_acqf_discrete`. - # TODO[santorella]: Rather than manually removing it, we should - # ensure that it is never passed. + # `raw_samples` and `num_restarts` are not supported by + # `optimize_acqf_discrete`. if optimizer_options is not None: optimizer_options.pop("raw_samples", None) + optimizer_options.pop("num_restarts", None) else: n_combos = math.prod([len(v) for v in discrete_choices.values()]) # If there are diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index d0805591e15..b0132dc775d 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -383,8 +383,8 @@ def test_optimize_discrete(self) -> None: n = 2 # Also check that it runs when optimizer options are provided, whether - # `raw_samples` are present or not. - for optimizer_options in [None, {"raw_samples": 8}, {}]: + # `raw_samples` or `num_restarts` is present or not. + for optimizer_options in [None, {"raw_samples": 8}, {"num_restarts": 8}]: with self.subTest(optimizer_options=optimizer_options): acquisition.optimize( n=n,