diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py
index 063b1d3db01..9fdc72e9cc3 100644
--- a/ax/models/tests/test_botorch_model.py
+++ b/ax/models/tests/test_botorch_model.py
@@ -148,14 +148,11 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
         n = 3
         X_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], dtype=dtype, device=device)
-        acq_dummy = torch.tensor(0.0, dtype=dtype, device=device)
         model_gen_options = {}

         # test sequential optimize
         with mock.patch(
-            "ax.models.torch.botorch_defaults.optimize_acqf",
-            return_value=(X_dummy, acq_dummy),
+            "ax.models.torch.botorch_defaults.sequential_optimize", return_value=X_dummy
         ) as mock_optimize_acqf:
-
             Xgen, wgen = model.gen(
                 n=n,
                 bounds=bounds,
@@ -173,8 +170,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):

         # test joint optimize
         with mock.patch(
-            "ax.models.torch.botorch_defaults.optimize_acqf",
-            return_value=(X_dummy, acq_dummy),
+            "ax.models.torch.botorch_defaults.joint_optimize", return_value=X_dummy
         ) as mock_optimize_acqf:
             Xgen, wgen = model.gen(
                 n=n,
diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py
index 6b317a48b80..80c9b781965 100644
--- a/ax/models/torch/botorch.py
+++ b/ax/models/torch/botorch.py
@@ -55,7 +55,7 @@
         Optional[Callable[[Tensor], Tensor]],
         Any,
     ],
-    Tuple[Tensor, Tensor],
+    Tensor,
 ]


@@ -146,7 +146,7 @@ class BotorchModel(TorchModel):
             fixed_features,
             rounding_func,
             **kwargs,
-        ) -> (candidates, acq_values)
+        ) -> candidates

     Here `acq_function` is a BoTorch `AcquisitionFunction`, `bounds` is a
     tensor containing bounds on the parameters, `n` is the number of
@@ -154,8 +154,7 @@ class BotorchModel(TorchModel):
     constraints on parameter values, `fixed_features` specifies features that
     should be fixed during generation, and `rounding_func` is a callback
     that rounds an optimization result appropriately. `candidates` is
-    a tensor of generated candidates, and `acq_values` are the acquisition
-    values associated with the candidates. For additional details on the
+    a tensor of generated candidates. For additional details on the
     arguments, see `scipy_optimizer`.
     """

@@ -316,7 +315,7 @@ def gen(

         botorch_rounding_func = get_rounding_func(rounding_func)

-        candidates, _ = self.acqf_optimizer(  # pyre-ignore: [28]
+        candidates = self.acqf_optimizer(  # pyre-ignore: [28]
             acq_function=checked_cast(AcquisitionFunction, acquisition_function),
             bounds=bounds_,
             n=n,
diff --git a/ax/models/torch/botorch_defaults.py b/ax/models/torch/botorch_defaults.py
index 6606f5005a8..1c5a8247ef6 100644
--- a/ax/models/torch/botorch_defaults.py
+++ b/ax/models/torch/botorch_defaults.py
@@ -18,7 +18,7 @@
 from botorch.models.model import Model
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
-from botorch.optim.optimize import optimize_acqf
+from botorch.optim.optimize import joint_optimize, sequential_optimize
 from botorch.utils import (
     get_objective_weights_transform,
     get_outcome_constraint_transforms,
@@ -204,7 +204,7 @@ def scipy_optimizer(
     fixed_features: Optional[Dict[int, float]] = None,
     rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
     **kwargs: Any,
-) -> Tuple[Tensor, Tensor]:
+) -> Tensor:
     r"""Optimizer using scipy's minimize module on a numpy-adpator.

     Args:
@@ -233,12 +233,15 @@ def scipy_optimizer(
     num_restarts: int = kwargs.get("num_restarts", 20)
     raw_samples: int = kwargs.get("num_raw_samples", 50 * num_restarts)

-    sequential = not kwargs.get("joint_optimization", False)
-    # use SLSQP by default for small problems since it yields faster wall times
-    if sequential and "method" not in kwargs:
-        kwargs["method"] = "SLSQP"
+    if kwargs.get("joint_optimization", False):
+        optimize = joint_optimize
+    else:
+        optimize = sequential_optimize
+        # use SLSQP by default for small problems since it yields faster wall times
+        if "method" not in kwargs:
+            kwargs["method"] = "SLSQP"

-    return optimize_acqf(
+    X = optimize(
         acq_function=acq_function,
         bounds=bounds,
         q=n,
@@ -248,8 +251,11 @@ def scipy_optimizer(
         inequality_constraints=inequality_constraints,
         fixed_features=fixed_features,
         post_processing_func=rounding_func,
-        sequential=not kwargs.get("joint_optimization", False),
     )
+    # TODO: Un-hack this once botorch #234 is part of a stable release
+    if isinstance(X, tuple):
+        X, _ = X
+    return X


 def _get_model(
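
The patch narrows the optimizer callable contract in ax/models/torch/botorch.py to return only `candidates` (a single Tensor) and routes `scipy_optimizer` through BoTorch's `joint_optimize`/`sequential_optimize`, with a temporary shim for versions where those functions return a tuple. As an illustration of the new contract, here is a minimal sketch of a user-supplied `acqf_optimizer`. It is not part of the patch: `my_acqf_optimizer` is a hypothetical name, and the BoTorch keyword arguments simply mirror the ones `scipy_optimizer` passes in the diff above.

# A minimal sketch (not part of this patch) of a custom acquisition-function
# optimizer that satisfies the updated contract: return only `candidates`,
# never a (candidates, acq_values) tuple.
from typing import Any, Callable, Dict, List, Optional, Tuple

from botorch.acquisition import AcquisitionFunction
from botorch.optim.optimize import sequential_optimize
from torch import Tensor


def my_acqf_optimizer(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    n: int,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    fixed_features: Optional[Dict[int, float]] = None,
    rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
    **kwargs: Any,
) -> Tensor:
    num_restarts: int = kwargs.get("num_restarts", 20)
    X = sequential_optimize(
        acq_function=acq_function,
        bounds=bounds,
        q=n,
        num_restarts=num_restarts,
        raw_samples=kwargs.get("num_raw_samples", 50 * num_restarts),
        inequality_constraints=inequality_constraints,
        fixed_features=fixed_features,
        post_processing_func=rounding_func,
    )
    # BoTorch versions differ on whether the optimizer returns a bare tensor
    # or a (candidates, acq_values) tuple (see botorch #234), so unpack
    # defensively, as the patched scipy_optimizer does.
    if isinstance(X, tuple):
        X, _ = X
    return X

Assuming `BotorchModel` accepts such a callable through its `acqf_optimizer` argument (as the class docstring edited above describes), `gen` would invoke it exactly as in the `candidates = self.acqf_optimizer(...)` hunk.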