diff --git a/ax/benchmark/benchmark_test_functions/synthetic.py b/ax/benchmark/benchmark_test_functions/synthetic.py
new file mode 100644
index 00000000000..2ec86a5f299
--- /dev/null
+++ b/ax/benchmark/benchmark_test_functions/synthetic.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, field
+
+import torch
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
+
+
+@dataclass(kw_only=True)
+class IdentityTestFunction(BenchmarkTestFunction):
+    """
+    Test function that returns the value of parameter "x0", ignoring any others.
+    """
+
+    outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
+    n_steps: int = 1
+
+    # pyre-fixme[14]: Inconsistent override
+    def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
+        """
+        Return params["x0"] for each outcome for each time step.
+
+        Args:
+            params: A dictionary with key "x0".
+        """
+        value = params["x0"]
+        return torch.full(
+            (len(self.outcome_names), self.n_steps), value, dtype=torch.float64
+        )
diff --git a/ax/benchmark/problems/registry.py b/ax/benchmark/problems/registry.py
index 6b0913e70e7..51f6c6e009b 100644
--- a/ax/benchmark/problems/registry.py
+++ b/ax/benchmark/problems/registry.py
@@ -16,6 +16,7 @@
     get_pytorch_cnn_torchvision_benchmark_problem,
 )
 from ax.benchmark.problems.runtime_funcs import int_from_params
+from ax.benchmark.problems.synthetic.bandit import get_bandit_problem
 from ax.benchmark.problems.synthetic.discretized.mixed_integer import (
     get_discrete_ackley,
     get_discrete_hartmann,
@@ -55,6 +56,9 @@ class BenchmarkProblemRegistryEntry:
             "name": "ackley4_async_noisy",
         },
     ),
+    "Bandit": BenchmarkProblemRegistryEntry(
+        factory_fn=get_bandit_problem, factory_kwargs={}
+    ),
     "branin": BenchmarkProblemRegistryEntry(
         factory_fn=create_problem_from_botorch,
         factory_kwargs={
diff --git a/ax/benchmark/problems/synthetic/bandit.py b/ax/benchmark/problems/synthetic/bandit.py
new file mode 100644
index 00000000000..a12419fdc40
--- /dev/null
+++ b/ax/benchmark/problems/synthetic/bandit.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from warnings import warn
+
+import numpy as np
+
+from ax.benchmark.benchmark_problem import BenchmarkProblem, get_soo_opt_config
+from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
+from ax.core.parameter import ChoiceParameter, ParameterType
+from ax.core.search_space import SearchSpace
+
+
+def get_baseline(num_choices: int, n_sims: int = 100000000) -> float:
+    """
+    Compute the baseline value.
+
+    The baseline for this problem takes into account noise, because it uses the
+    inference trace, and the bandit structure, which allows for running all arms
+    in one noisy batch:
+
+    Run a BatchTrial with every arm, with equal size. Choose the arm with the
+    best observed value and take its true value. Take the expectation of the
+    outcome of this process.
+    """
+    noise_per_arm = num_choices**0.5
+    sim_observed_effects = (
+        np.random.normal(0, noise_per_arm, (n_sims, num_choices))
+        + np.arange(num_choices)[None, :]
+    )
+    identified_best_arm = sim_observed_effects.argmin(axis=1)
+    # because of the use of IdentityTestFunction
+    baseline = identified_best_arm.mean()
+    return baseline
+
+
+def get_bandit_problem(num_choices: int = 30, num_trials: int = 3) -> BenchmarkProblem:
+    parameter = ChoiceParameter(
+        name="x0",
+        parameter_type=ParameterType.INT,
+        values=list(range(num_choices)),
+        is_ordered=False,
+        sort_values=False,
+    )
+    search_space = SearchSpace(parameters=[parameter])
+    test_function = IdentityTestFunction()
+    optimization_config = get_soo_opt_config(
+        outcome_names=test_function.outcome_names, observe_noise_sd=True
+    )
+    baselines = {
+        10: 1.40736478,
+        30: 2.4716703,
+        100: 4.403284,
+    }
+    if num_choices not in baselines:
+        warn(
+            f"Baseline value is not available for num_choices={num_choices}. Use "
+            "`get_baseline` to compute the baseline and add it to `baselines`."
+        )
+        baseline_value = baselines[30]
+    else:
+        baseline_value = baselines[num_choices]
+    return BenchmarkProblem(
+        name="Bandit",
+        num_trials=num_trials,
+        search_space=search_space,
+        optimization_config=optimization_config,
+        optimal_value=0,
+        baseline_value=baseline_value,
+        test_function=test_function,
+        report_inference_value_as_trace=True,
+        noise_std=1.0,
+        status_quo_params={"x0": num_choices // 2},
+    )
diff --git a/ax/benchmark/tests/problems/synthetic/test_bandit.py b/ax/benchmark/tests/problems/synthetic/test_bandit.py
new file mode 100644
index 00000000000..cb49345470a
--- /dev/null
+++ b/ax/benchmark/tests/problems/synthetic/test_bandit.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from ax.benchmark.problems.synthetic.bandit import get_bandit_problem, get_baseline
+from ax.utils.common.testutils import TestCase
+
+
+class TestProblems(TestCase):
+    def test_get_baseline(self) -> None:
+        num_choices = 5
+        baseline = get_baseline(num_choices=num_choices, n_sims=100)
+        self.assertGreater(baseline, 0)
+        # Worst = num_choices - 1; random guessing = (num_choices - 1) / 2
+        self.assertLess(baseline, (num_choices - 1) / 2)
+
+    def test_get_bandit_problem(self) -> None:
+        problem = get_bandit_problem()
+        self.assertEqual(problem.name, "Bandit")
+        self.assertEqual(problem.num_trials, 3)
+        self.assertTrue(problem.report_inference_value_as_trace)
+
+        problem = get_bandit_problem(num_choices=26, num_trials=4)
+        self.assertEqual(problem.num_trials, 4)
+        self.assertEqual(problem.status_quo_params, {"x0": 26 // 2})
+
+    def test_baseline_exception(self) -> None:
+        with self.assertWarnsRegex(
+            Warning, expected_regex="Baseline value is not available for num_choices=17"
+        ):
+            problem = get_bandit_problem(num_choices=17)
+
+        self.assertEqual(problem.baseline_value, get_bandit_problem().baseline_value)
diff --git a/ax/benchmark/tests/problems/test_problems.py b/ax/benchmark/tests/problems/test_problems.py
index e43b2055d4e..9a7971ef2e0 100644
--- a/ax/benchmark/tests/problems/test_problems.py
+++ b/ax/benchmark/tests/problems/test_problems.py
@@ -24,6 +24,7 @@ def test_load_problems(self) -> None:
 
     def test_name(self) -> None:
         expected_names = [
+            ("Bandit", "Bandit"),
             ("branin", "Branin"),
             ("hartmann3", "Hartmann_3d"),
             ("hartmann6", "Hartmann_6d"),
diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py
index 46972341026..29be755ec5a 100644
--- a/ax/benchmark/tests/test_benchmark.py
+++ b/ax/benchmark/tests/test_benchmark.py
@@ -31,6 +31,7 @@
 )
 from ax.benchmark.benchmark_result import BenchmarkResult
 from ax.benchmark.benchmark_runner import BenchmarkRunner
+from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
 from ax.benchmark.methods.modular_botorch import (
     get_sobol_botorch_modular_acquisition,
     get_sobol_mbm_generation_strategy,
@@ -61,7 +62,6 @@
     get_multi_objective_benchmark_problem,
     get_single_objective_benchmark_problem,
     get_soo_surrogate,
-    IdentityTestFunction,
     TestDataset,
 )
 
diff --git a/ax/benchmark/tests/test_benchmark_runner.py b/ax/benchmark/tests/test_benchmark_runner.py
index 0fa9c9a6251..6d2582ee867 100644
--- a/ax/benchmark/tests/test_benchmark_runner.py
+++ b/ax/benchmark/tests/test_benchmark_runner.py
@@ -18,6 +18,8 @@
 from ax.benchmark.benchmark_runner import _add_noise, BenchmarkRunner
 from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
 from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
+
+from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
 from ax.benchmark.problems.synthetic.hss.jenatton import (
     get_jenatton_benchmark_problem,
     Jenatton,
@@ -35,8 +37,8 @@
     DummyTestFunction,
     get_jenatton_trials,
     get_soo_surrogate_test_function,
-    IdentityTestFunction,
 )
+
 from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
 from botorch.utils.transforms import normalize
 from pandas import DataFrame
diff --git a/ax/modelbridge/tests/test_discrete_modelbridge.py b/ax/modelbridge/tests/test_discrete_modelbridge.py
index d89db4a504d..bb8b5250c5a 100644
--- a/ax/modelbridge/tests/test_discrete_modelbridge.py
+++ b/ax/modelbridge/tests/test_discrete_modelbridge.py
@@ -276,8 +276,6 @@ def test_cross_validate(self, mock_init: Mock) -> None:
     def test_get_parameter_values(self) -> None:
         parameter_values = _get_parameter_values(self.search_space, ["x", "y", "z"])
         self.assertEqual(parameter_values, [[0.0, 1.0], ["foo", "bar"], [True]])
-        # pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
-        #  `List[Union[ChoiceParameter, FixedParameter]]`.
         search_space = SearchSpace(self.parameters)
         search_space._parameters["x"] = RangeParameter(
             "x", ParameterType.FLOAT, 0.1, 0.4
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index cd42ae70b24..e50dcfa0e3c 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, field
 from typing import Any, Iterator
 
@@ -28,6 +27,7 @@
 from ax.benchmark.benchmark_step_runtime_function import TBenchmarkStepRuntimeFunction
 from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
 from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
+from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
 from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_search_space
 from ax.core.arm import Arm
 from ax.core.batch_trial import BatchTrial
@@ -301,23 +301,6 @@ def get_next_candidate(
         return {self.param_name: next(self.iterator)}
 
 
-@dataclass(kw_only=True)
-class IdentityTestFunction(BenchmarkTestFunction):
-    outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
-    n_steps: int = 1
-
-    # pyre-fixme[14]: Inconsistent override
-    def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
-        """
-        Args:
-            params: A dictionary with key "x0".
-        """
-        value = params["x0"]
-        return torch.full(
-            (len(self.outcome_names), self.n_steps), value, dtype=torch.float64
-        )
-
-
 def get_discrete_search_space(n_values: int = 20) -> SearchSpace:
     return SearchSpace(
         parameters=[
diff --git a/sphinx/source/benchmark.rst b/sphinx/source/benchmark.rst
index 2ad58067b9d..f8f707f79af 100644
--- a/sphinx/source/benchmark.rst
+++ b/sphinx/source/benchmark.rst
@@ -106,6 +106,14 @@ Benchmark Problems Registry
     :undoc-members:
     :show-inheritance:
 
+Benchmark Problems: Bandit
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: ax.benchmark.problems.synthetic.bandit
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Benchmark Problems High Dimensional Embedding
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -146,6 +154,7 @@ Benchmark Problems PyTorchCNN TorchVision
     :undoc-members:
     :show-inheritance:
 
+
 Benchmark Problems Runtime Functions
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -169,3 +178,11 @@ Benchmark Test Functions: Surrogate
     :members:
     :undoc-members:
     :show-inheritance:
+
+Benchmark Test Functions: Synthetic
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: ax.benchmark.benchmark_test_functions.synthetic
+    :members:
+    :undoc-members:
+    :show-inheritance:
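Usage note (not part of the patch): a minimal sketch of how the problem and baseline helper introduced above might be exercised, using only the functions this diff adds. The small `n_sims` value is for illustration only; it is far below the default of 1e8 used for the registered baselines.

```python
from ax.benchmark.problems.synthetic.bandit import get_baseline, get_bandit_problem

# Default 30-arm problem: 3 trials, status quo at arm 15, precomputed baseline.
problem = get_bandit_problem()
print(problem.baseline_value)  # 2.4716703, looked up in the `baselines` dict

# A size without a registered baseline (e.g. 17) falls back to the 30-arm value
# and emits the warning added in this diff; an approximate baseline can be
# recomputed with a reduced simulation budget.
problem_17 = get_bandit_problem(num_choices=17, num_trials=4)
approx_baseline = get_baseline(num_choices=17, n_sims=100_000)
```

Because the test function is the identity on `x0` and arms are indexed 0..num_choices-1, the expected true value of the arm that looks best after one noisy batch is exactly what `get_baseline` estimates by Monte Carlo.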