Skip to content

Commit

Permalink
Fix & deprecate REMBO (#1926)
Browse files Browse the repository at this point in the history
Summary:
Un-mocking `optimize_acqf` reproduces the failures reported in #1924

Changed where the roundtrip transforms are applied to fix the failure.

Also removed some mocks from ALEBO tests. ~~It likely has the same issue but it doesn't have e2e gen tests that'd produce them.~~ It doesn't have the same issue since it has a custom acqf optimizer.


Differential Revision: D50572494

Pulled By: saitcakmak
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 23, 2023
1 parent 421dd10 commit 9b4bf97
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 68 deletions.
24 changes: 4 additions & 20 deletions ax/modelbridge/tests/test_rembo_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import torch
from ax.core.experiment import Experiment
from ax.core.objective import Objective
Expand All @@ -18,21 +16,12 @@
from ax.modelbridge.strategies.rembo import HeSBOStrategy, REMBOStrategy
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import fast_botorch_optimize


class REMBOStrategyTest(TestCase):
@patch(
"ax.models.torch.botorch_defaults.optimize_acqf",
autospec=True,
return_value=(
torch.randn((2, 6), dtype=torch.double),
torch.randn((2, 6), dtype=torch.double),
),
)
@patch("ax.models.torch.botorch_defaults.fit_gpytorch_mll", autospec=True)
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf):
@fast_botorch_optimize
def test_REMBOStrategy(self) -> None:
# Construct a high-D test experiment with multiple metrics
hartmann_search_space = SearchSpace(
parameters=[
Expand Down Expand Up @@ -88,7 +77,6 @@ def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf):
# Iterate until the first projection fits a GP
for _ in range(4):
exp.new_batch_trial(generator_run=gs.gen(experiment=exp, n=2)).run()
mock_fit_gpytorch_mll.assert_not_called()

self.assertEqual(len(gs.arms_by_proj[0]), 4)
self.assertEqual(len(gs.arms_by_proj[1]), 4)
Expand All @@ -104,12 +92,8 @@ def test_REMBOStrategy(self, mock_fit_gpytorch_mll, mock_optimize_acqf):
self.assertLess(len(gs.arms_by_proj[3]), 4)

exp.new_batch_trial(generator_run=gs.gen(experiment=exp, n=2)).run()
if i < 2:
mock_fit_gpytorch_mll.assert_not_called()
else:
# After all proj. have > 4 arms' worth of data, GP can be fit.
if i >= 2:
self.assertFalse(any(len(x) < 4 for x in gs.arms_by_proj.values()))
mock_fit_gpytorch_mll.assert_called()

self.assertTrue(len(gs.model_transitions) > 0)
gs2 = gs.clone_reset()
Expand Down
39 changes: 13 additions & 26 deletions ax/models/tests/test_alebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,38 +341,25 @@ def test_ALEBO(self) -> None:
)

# Test gen
# With clipping
with mock.patch(
"ax.models.torch.alebo.optimize_acqf",
autospec=True,
return_value=(m.Xs[0], torch.tensor([])),
):
gen_results = m.gen(
n=1,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)
gen_results = m.gen(
n=1,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)

self.assertFalse(torch.allclose(gen_results.points, train_X))
self.assertTrue(gen_results.points.min() >= -1)
self.assertTrue(gen_results.points.max() <= 1)
# Without
with mock.patch(
"ax.models.torch.alebo.optimize_acqf",
autospec=True,
return_value=(torch.ones(1, 2, dtype=torch.double), torch.tensor([])),
):
gen_results = m.gen(
n=1,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)
gen_results = m.gen(
n=1,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)

self.assertTrue(
torch.allclose(
gen_results.points,
torch.tensor([[-0.2, -0.1, 0.0, 0.1, 0.2]], dtype=torch.double),
)
self.assertEqual(
gen_results.points.shape,
torch.Size([1, 5]),
)

# Test get_and_fit with single metric
Expand Down
19 changes: 5 additions & 14 deletions ax/models/tests/test_rembo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.rembo import REMBO
Expand Down Expand Up @@ -115,18 +113,11 @@ def test_REMBOModel(self) -> None:
self.assertEqual(f.shape, torch.Size([1, 2]))

# Test gen
Xgen_d = torch.tensor([[0.4, 0.8], [-0.2, 1.0]])
acqfv_dummy = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
with mock.patch(
"ax.models.torch.botorch_defaults.optimize_acqf",
autospec=True,
return_value=(Xgen_d, acqfv_dummy),
):
gen_results = m.gen(
n=2,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)
gen_results = m.gen(
n=2,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
)

self.assertEqual(gen_results.points.shape[1], 4)
self.assertEqual(len(m.X_d), 5)
26 changes: 18 additions & 8 deletions ax/models/torch/rembo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dataclasses
from typing import Any, List, Optional, Tuple
from warnings import warn

import torch
from ax.core.search_space import SearchSpaceDigest
Expand Down Expand Up @@ -44,14 +45,18 @@ def __init__(
bounds_d: List[Tuple[float, float]],
**kwargs: Any,
) -> None:
warn(
"REMBO is deprecated and does not guarantee correctness. "
"It will be removed in Ax 0.5.0.",
DeprecationWarning,
)
self.A = A
# pyre-fixme[4]: Attribute must be annotated.
self._pinvA = torch.pinverse(A) # compute pseudo inverse once and cache it
# compute pseudo inverse once and cache it
self._pinvA: Tensor = torch.pinverse(A)
# Projected points in low-d space generated in the optimization
# pyre-fixme[4]: Attribute must be annotated.
self.X_d = list(initial_X_d)
# pyre-fixme[4]: Attribute must be annotated.
self.X_d_gen = [] # Projected points that were generated by this model
self.X_d: List[Tensor] = list(initial_X_d)
# Projected points that were generated by this model
self.X_d_gen: List[Tensor] = []
self.bounds_d = bounds_d
self.num_outputs = 0
super().__init__(**kwargs)
Expand Down Expand Up @@ -180,19 +185,24 @@ def gen(
assert torch_opt_config.fixed_features is None
assert torch_opt_config.pending_observations is None
# Do gen in the low-dimensional space and project up
rounding_func = torch_opt_config.rounding_func
gen_results = super().gen(
n=n,
search_space_digest=dataclasses.replace(
search_space_digest,
bounds=[(0.0, 1.0)] * len(self.bounds_d),
),
torch_opt_config=torch_opt_config,
torch_opt_config=dataclasses.replace(torch_opt_config, rounding_func=None),
)
Xopt = self.from_01(gen_results.points)
self.X_d.extend([x.clone() for x in Xopt])
self.X_d_gen.extend([x.clone() for x in Xopt])
gen_points = self.project_up(Xopt)
if rounding_func is not None:
for i in range(len(gen_points)):
gen_points[i] = rounding_func(gen_points[i])
return TorchGenResults(
points=self.project_up(Xopt),
points=gen_points,
weights=gen_results.weights,
)

Expand Down

0 comments on commit 9b4bf97

Please sign in to comment.