From 9b4bf9709e3dfc3b8e53b87582fa159ad273a6b4 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 23 Oct 2023 16:22:12 -0700 Subject: [PATCH] Fix & deprecate REMBO (#1926) Summary: Un-mocking `optimize_acqf` reproduces the failures reported in https://github.com/facebook/Ax/issues/1924 Changed where the roundtrip transforms are applied to fix the failure. Also removed some mocks from ALEBO tests. ~~It likely has the same issue but it doesn't have e2e gen tests that'd produce them.~~ It doesn't have the same issue since it has a custom acqf optimizer. Differential Revision: D50572494 Pulled By: saitcakmak --- ax/modelbridge/tests/test_rembo_strategy.py | 24 +++---------- ax/models/tests/test_alebo.py | 39 +++++++-------------- ax/models/tests/test_rembo.py | 19 +++------- ax/models/torch/rembo.py | 26 +++++++++----- 4 files changed, 40 insertions(+), 68 deletions(-) diff --git a/ax/modelbridge/tests/test_rembo_strategy.py b/ax/modelbridge/tests/test_rembo_strategy.py index 74f608671ad..02ae35e6b99 100644 --- a/ax/modelbridge/tests/test_rembo_strategy.py +++ b/ax/modelbridge/tests/test_rembo_strategy.py @@ -4,8 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from unittest.mock import patch - import torch from ax.core.experiment import Experiment from ax.core.objective import Objective @@ -18,21 +16,12 @@ from ax.modelbridge.strategies.rembo import HeSBOStrategy, REMBOStrategy from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import fast_botorch_optimize class REMBOStrategyTest(TestCase): - @patch( - "ax.models.torch.botorch_defaults.optimize_acqf", - autospec=True, - return_value=( - torch.randn((2, 6), dtype=torch.double), - torch.randn((2, 6), dtype=torch.double), - ), - ) - @patch("ax.models.torch.botorch_defaults.fit_gpytorch_mll", autospec=True) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf): + @fast_botorch_optimize + def test_REMBOStrategy(self) -> None: # Construct a high-D test experiment with multiple metrics hartmann_search_space = SearchSpace( parameters=[ @@ -88,7 +77,6 @@ def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf): # Iterate until the first projection fits a GP for _ in range(4): exp.new_batch_trial(generator_run=gs.gen(experiment=exp, n=2)).run() - mock_fit_gpytorch_mll.assert_not_called() self.assertEqual(len(gs.arms_by_proj[0]), 4) self.assertEqual(len(gs.arms_by_proj[1]), 4) @@ -104,12 +92,8 @@ def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf): self.assertLess(len(gs.arms_by_proj[3]), 4) exp.new_batch_trial(generator_run=gs.gen(experiment=exp, n=2)).run() - if i < 2: - mock_fit_gpytorch_mll.assert_not_called() - else: - # After all proj. have > 4 arms' worth of data, GP can be fit. + if i >= 2: self.assertFalse(any(len(x) < 4 for x in gs.arms_by_proj.values())) - mock_fit_gpytorch_mll.assert_called() self.assertTrue(len(gs.model_transitions) > 0) gs2 = gs.clone_reset() diff --git a/ax/models/tests/test_alebo.py b/ax/models/tests/test_alebo.py index 6c87764e577..b122c93a27c 100644 --- a/ax/models/tests/test_alebo.py +++ b/ax/models/tests/test_alebo.py @@ -341,38 +341,25 @@ def test_ALEBO(self) -> None: ) # Test gen - # With clipping - with mock.patch( - "ax.models.torch.alebo.optimize_acqf", - autospec=True, - return_value=(m.Xs[0], torch.tensor([])), - ): - gen_results = m.gen( - n=1, - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) + gen_results = m.gen( + n=1, + search_space_digest=search_space_digest, + torch_opt_config=torch_opt_config, + ) self.assertFalse(torch.allclose(gen_results.points, train_X)) self.assertTrue(gen_results.points.min() >= -1) self.assertTrue(gen_results.points.max() <= 1) # Without - with mock.patch( - "ax.models.torch.alebo.optimize_acqf", - autospec=True, - return_value=(torch.ones(1, 2, dtype=torch.double), torch.tensor([])), - ): - gen_results = m.gen( - n=1, - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) + gen_results = m.gen( + n=1, + search_space_digest=search_space_digest, + torch_opt_config=torch_opt_config, + ) - self.assertTrue( - torch.allclose( - gen_results.points, - torch.tensor([[-0.2, -0.1, 0.0, 0.1, 0.2]], dtype=torch.double), - ) + self.assertEqual( + gen_results.points.shape, + torch.Size([1, 5]), ) # Test get_and_fit with single metric diff --git a/ax/models/tests/test_rembo.py b/ax/models/tests/test_rembo.py index 1f8ffb3b1d0..313b42826b3 100644 --- a/ax/models/tests/test_rembo.py +++ b/ax/models/tests/test_rembo.py @@ -4,8 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from unittest import mock - import torch from ax.core.search_space import SearchSpaceDigest from ax.models.torch.rembo import REMBO @@ -115,18 +113,11 @@ def test_REMBOModel(self) -> None: self.assertEqual(f.shape, torch.Size([1, 2])) # Test gen - Xgen_d = torch.tensor([[0.4, 0.8], [-0.2, 1.0]]) - acqfv_dummy = torch.tensor([[0.0, 0.0], [0.0, 0.0]]) - with mock.patch( - "ax.models.torch.botorch_defaults.optimize_acqf", - autospec=True, - return_value=(Xgen_d, acqfv_dummy), - ): - gen_results = m.gen( - n=2, - search_space_digest=search_space_digest, - torch_opt_config=torch_opt_config, - ) + gen_results = m.gen( + n=2, + search_space_digest=search_space_digest, + torch_opt_config=torch_opt_config, + ) self.assertEqual(gen_results.points.shape[1], 4) self.assertEqual(len(m.X_d), 5) diff --git a/ax/models/torch/rembo.py b/ax/models/torch/rembo.py index ab6f8f34fd6..2a470b05555 100644 --- a/ax/models/torch/rembo.py +++ b/ax/models/torch/rembo.py @@ -6,6 +6,7 @@ import dataclasses from typing import Any, List, Optional, Tuple +from warnings import warn import torch from ax.core.search_space import SearchSpaceDigest @@ -44,14 +45,18 @@ def __init__( bounds_d: List[Tuple[float, float]], **kwargs: Any, ) -> None: + warn( + "REMBO is deprecated and does not guarantee correctness. " + "It will be removed in Ax 0.5.0.", + DeprecationWarning, + ) self.A = A - # pyre-fixme[4]: Attribute must be annotated. - self._pinvA = torch.pinverse(A) # compute pseudo inverse once and cache it + # compute pseudo inverse once and cache it + self._pinvA: Tensor = torch.pinverse(A) # Projected points in low-d space generated in the optimization - # pyre-fixme[4]: Attribute must be annotated. - self.X_d = list(initial_X_d) - # pyre-fixme[4]: Attribute must be annotated. - self.X_d_gen = [] # Projected points that were generated by this model + self.X_d: List[Tensor] = list(initial_X_d) + # Projected points that were generated by this model + self.X_d_gen: List[Tensor] = [] self.bounds_d = bounds_d self.num_outputs = 0 super().__init__(**kwargs) @@ -180,19 +185,24 @@ def gen( assert torch_opt_config.fixed_features is None assert torch_opt_config.pending_observations is None # Do gen in the low-dimensional space and project up + rounding_func = torch_opt_config.rounding_func gen_results = super().gen( n=n, search_space_digest=dataclasses.replace( search_space_digest, bounds=[(0.0, 1.0)] * len(self.bounds_d), ), - torch_opt_config=torch_opt_config, + torch_opt_config=dataclasses.replace(torch_opt_config, rounding_func=None), ) Xopt = self.from_01(gen_results.points) self.X_d.extend([x.clone() for x in Xopt]) self.X_d_gen.extend([x.clone() for x in Xopt]) + gen_points = self.project_up(Xopt) + if rounding_func is not None: + for i in range(len(gen_points)): + gen_points[i] = rounding_func(gen_points[i]) return TorchGenResults( - points=self.project_up(Xopt), + points=gen_points, weights=gen_results.weights, )