diff --git a/ax/models/torch/botorch_modular/optimizer_argparse.py b/ax/models/torch/botorch_modular/optimizer_argparse.py index dde926cb2d5..e4351e032bc 100644 --- a/ax/models/torch/botorch_modular/optimizer_argparse.py +++ b/ax/models/torch/botorch_modular/optimizer_argparse.py @@ -8,16 +8,11 @@ from __future__ import annotations -from typing import Any, TypeVar, Union +from typing import Any from ax.exceptions.core import UnsupportedError -from ax.utils.common.typeutils import _argparse_type_encoder from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.knowledge_gradient import qKnowledgeGradient -from botorch.utils.dispatcher import Dispatcher - -T = TypeVar("T") -MaybeType = Union[T, type[T]] # Annotation for a type or instance thereof # Acquisition defaults NUM_RESTARTS = 20 @@ -26,14 +21,8 @@ BATCH_LIMIT = 5 -optimizer_argparse = Dispatcher( - name="optimizer_argparse", encoder=_argparse_type_encoder -) - - -@optimizer_argparse.register(AcquisitionFunction) -def _argparse_base( - acqf: MaybeType[AcquisitionFunction], +def optimizer_argparse( + acqf: AcquisitionFunction, *, optimizer: str, sequential: bool = True, @@ -102,6 +91,15 @@ def _argparse_base( f"optimizer=`{optimizer}` is not supported. Accepted options are " f"{supported_optimizers}" ) + if (optimizer != "optimize_acqf") and isinstance(acqf, qKnowledgeGradient): + raise RuntimeError( + "Ax is attempting to use a discrete or mixed optimizer, " + f"`{optimizer}`, but this is not compatible with " + "`qKnowledgeGradient` or its subclasses. To address this, please " + "either use a different acquisition class or make parameters " + "continuous using the transform " + "`ax.modelbridge.registry.Cont_X_trans`." + ) provided_options = optimizer_options if optimizer_options is not None else {} # Construct arguments from options that are not `provided_options`. @@ -138,41 +136,3 @@ def _argparse_base( options.update(**{k: v for k, v in provided_options.items() if k != "options"}) return options - - -@optimizer_argparse.register(qKnowledgeGradient) -def _argparse_kg( - acqf: qKnowledgeGradient, - *, - optimizer: str = "optimize_acqf", - sequential: bool = True, - num_restarts: int = NUM_RESTARTS, - raw_samples: int = RAW_SAMPLES, - init_batch_limit: int = INIT_BATCH_LIMIT, - batch_limit: int = BATCH_LIMIT, - optimizer_options: dict[str, Any] | None = None, - **ignore: Any, -) -> dict[str, Any]: - """ - Argument constructor for optimization with qKG, differing from the - base case in that it errors if the optimizer is not `optimize_acqf`. - """ - if optimizer != "optimize_acqf": - raise RuntimeError( - "Ax is attempting to use a discrete or mixed optimizer, " - f"`{optimizer}`, but this is not compatible with " - "`qKnowledgeGradient` or its subclasses. To address this, please " - "either use a different acquisition class or make parameters " - "continuous using the transform " - "`ax.modelbridge.registry.Cont_X_trans`." - ) - return _argparse_base( - acqf=acqf, - optimizer="optimize_acqf", - sequential=sequential, - num_restarts=num_restarts, - raw_samples=raw_samples, - init_batch_limit=init_batch_limit, - batch_limit=batch_limit, - optimizer_options=optimizer_options, - ) diff --git a/ax/models/torch/tests/test_optimizer_argparse.py b/ax/models/torch/tests/test_optimizer_argparse.py index b48aa8c31e4..f3727ae34a0 100644 --- a/ax/models/torch/tests/test_optimizer_argparse.py +++ b/ax/models/torch/tests/test_optimizer_argparse.py @@ -8,22 +8,18 @@ from __future__ import annotations -from itertools import product -from unittest.mock import patch +from unittest.mock import MagicMock from ax.exceptions.core import UnsupportedError from ax.models.torch.botorch_modular.optimizer_argparse import ( - _argparse_base, BATCH_LIMIT, INIT_BATCH_LIMIT, - MaybeType, NUM_RESTARTS, optimizer_argparse, RAW_SAMPLES, ) from ax.utils.common.testutils import TestCase from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import LogExpectedImprovement from botorch.acquisition.knowledge_gradient import ( qKnowledgeGradient, qMultiFidelityKnowledgeGradient, @@ -31,12 +27,19 @@ class DummyAcquisitionFunction(AcquisitionFunction): - pass + def __init__(self) -> None: + return + + # pyre-fixme[14]: Inconsistent override + # pyre-fixme[15]: Inconsistent override + def forward(self) -> int: + return 0 class OptimizerArgparseTest(TestCase): def setUp(self) -> None: super().setUp() + self.acqf = DummyAcquisitionFunction() self.default_expected_options = { "optimize_acqf": { "num_restarts": NUM_RESTARTS, @@ -70,57 +73,24 @@ def setUp(self) -> None: }, } - def test_notImplemented(self) -> None: - with self.assertRaisesRegex( - NotImplementedError, "Could not find signature for" - ): - optimizer_argparse[type(None)] # passing `None` produces a different error - def test_unsupported_optimizer(self) -> None: with self.assertRaisesRegex( ValueError, "optimizer=`wishful thinking` is not supported" ): - optimizer_argparse(LogExpectedImprovement, optimizer="wishful thinking") - - def test_register(self) -> None: - with patch.dict(optimizer_argparse.funcs, {}): - - @optimizer_argparse.register(DummyAcquisitionFunction) - def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None: - pass - - self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse) - - def test_fallback(self) -> None: - with patch.dict(optimizer_argparse.funcs, {}): - - @optimizer_argparse.register(AcquisitionFunction) - def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None: - pass - - self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse) + optimizer_argparse(self.acqf, optimizer="wishful thinking") def test_optimizer_options(self) -> None: - # qKG should have a bespoke test # currently there is only one function in fns_to_test - fns_to_test = [ - elt - for elt in optimizer_argparse.funcs.values() - if elt is not optimizer_argparse[qKnowledgeGradient] - ] user_options = {"foo": "bar", "num_restarts": 13} - for func, optimizer in product( - fns_to_test, - [ - "optimize_acqf", - "optimize_acqf_discrete", - "optimize_acqf_mixed", - "optimize_acqf_discrete_local_search", - ], - ): - with self.subTest(func=func, optimizer=optimizer): - parsed_options = func( - None, optimizer_options=user_options, optimizer=optimizer + for optimizer in [ + "optimize_acqf", + "optimize_acqf_discrete", + "optimize_acqf_mixed", + "optimize_acqf_discrete_local_search", + ]: + with self.subTest(optimizer=optimizer): + parsed_options = optimizer_argparse( + self.acqf, optimizer_options=user_options, optimizer=optimizer ) self.assertDictEqual( {**self.default_expected_options[optimizer], **user_options}, @@ -130,66 +100,69 @@ def test_optimizer_options(self) -> None: # Also test sub-options. inner_options = {"batch_limit": 10, "maxiter": 20} options = {"options": inner_options} - for func in fns_to_test: - for optimizer in [ - "optimize_acqf", - "optimize_acqf_mixed", - "optimize_acqf_mixed_alternating", - ]: - default = self.default_expected_options[optimizer] - parsed_options = func( - None, optimizer_options=options, optimizer=optimizer + for optimizer in [ + "optimize_acqf", + "optimize_acqf_mixed", + "optimize_acqf_mixed_alternating", + ]: + default = self.default_expected_options[optimizer] + parsed_options = optimizer_argparse( + self.acqf, optimizer_options=options, optimizer=optimizer + ) + expected_options = {k: v for k, v in default.items() if k != "options"} + if "options" in default: + expected_options["options"] = { + **default["options"], + **inner_options, + } + else: + expected_options["options"] = inner_options + self.assertDictEqual(expected_options, parsed_options) + + # Error out if options is specified for an optimizer that does + # not support the arg. + for optimizer in [ + "optimize_acqf_discrete", + "optimize_acqf_discrete_local_search", + ]: + with self.assertRaisesRegex(UnsupportedError, "`options` argument"): + optimizer_argparse( + self.acqf, + optimizer_options={"options": {"batch_limit": 10, "maxiter": 20}}, + optimizer=optimizer, ) - expected_options = {k: v for k, v in default.items() if k != "options"} - if "options" in default: - expected_options["options"] = { - **default["options"], - **inner_options, - } - else: - expected_options["options"] = inner_options - self.assertDictEqual(expected_options, parsed_options) - - # Error out if options is specified for an optimizer that does - # not support the arg. - for optimizer in [ - "optimize_acqf_discrete", - "optimize_acqf_discrete_local_search", - ]: - with self.assertRaisesRegex(UnsupportedError, "`options` argument"): - func( - None, - optimizer_options={ - "options": {"batch_limit": 10, "maxiter": 20} - }, - optimizer=optimizer, - ) - - # `sequential=False` with optimizers other than `optimize_acqf`. - for optimizer in [ - "optimize_acqf_homotopy", - "optimize_acqf_mixed", - "optimize_acqf_mixed_alternating", - "optimize_acqf_discrete", - "optimize_acqf_discrete_local_search", - ]: - with self.assertRaisesRegex( - UnsupportedError, "does not support `sequential=False`" - ): - func(None, sequential=False, optimizer=optimizer) + + # `sequential=False` with optimizers other than `optimize_acqf`. + for optimizer in [ + "optimize_acqf_homotopy", + "optimize_acqf_mixed", + "optimize_acqf_mixed_alternating", + "optimize_acqf_discrete", + "optimize_acqf_discrete_local_search", + ]: + with self.assertRaisesRegex( + UnsupportedError, "does not support `sequential=False`" + ): + optimizer_argparse(self.acqf, sequential=False, optimizer=optimizer) def test_kg(self) -> None: user_options = {"foo": "bar", "num_restarts": 114} - generic_options = _argparse_base( - None, optimizer_options=user_options, optimizer="optimize_acqf" + generic_options = optimizer_argparse( + self.acqf, optimizer_options=user_options, optimizer="optimize_acqf" ) - for acqf in (qKnowledgeGradient, qMultiFidelityKnowledgeGradient): + for acqf in ( + qKnowledgeGradient(model=MagicMock(), posterior_transform=MagicMock()), + qMultiFidelityKnowledgeGradient( + model=MagicMock(), posterior_transform=MagicMock() + ), + ): with self.subTest(acqf=acqf): options = optimizer_argparse( acqf, q=None, bounds=None, optimizer_options=user_options, + optimizer="optimize_acqf", ) self.assertEqual(options, generic_options)