Skip to content

Commit

Permalink
Implement get_best_arm, get_pareto_frontier (#3142)
Browse files Browse the repository at this point in the history
Summary:

Implement "best point" functionality in Client. A few notes:
* Renamed get_best_trial --> get_best_arm: this is more accurate to what it is actually doing (calculating the best in sample point) and renaming will allow us to continue using this method when we create BatchClient
* If GS is not on a predictive step we return {} for the prediction term in both get_best_arm and get_pareto_frontier as well as log an error

Reviewed By: lena-kashtelyan

Differential Revision: D66702545
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Dec 18, 2024
1 parent ceb07f5 commit 7f9b632
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 16 deletions.
103 changes: 89 additions & 14 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ax.core.trial import Trial
from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.preview.api.configs import (
DatabaseConfig,
Expand All @@ -42,6 +42,7 @@
)
from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy
from ax.service.scheduler import Scheduler, SchedulerOptions
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.logger import get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance, none_throws
Expand Down Expand Up @@ -657,29 +658,103 @@ def compute_analyses(

return cards

def get_best_trial(
def get_best_parameterization(
self, use_model_predictions: bool = True
) -> tuple[int, TParameterization, TOutcome]:
) -> tuple[TParameterization, TOutcome, int, str]:
"""
Calculates the best in-sample trial.
Identifies the best parameterization tried in the experiment so far, also
called the best in-sample arm.
If `use_model_predictions` is True, first attempts to do so with the model used
in optimization and its corresponding predictions if available. If
`use_model_predictions` is False or attempts to use the model fails, falls back
to the best raw objective based on the data fetched from the experiment.
Parameterizations which were observed to violate outcome constraints are not
eligible to be the best parameterization.
Returns:
- The index of the best trial
- The parameters of the best trial
- The metric values associated withthe best trial
- The parameters predicted to have the best optimization value without
violating any outcome constraints.
- The metric values for the best parameterization. Uses model prediction if
use_model_predictions=True, otherwise returns observed data.
- The trial which most recently ran the best parameterization
- The name of the best arm (each trial has a unique name associated with
each parameterization)
"""
...

if len(self._none_throws_experiment().trials) < 1:
raise UnsupportedError(
"No trials have been run yet. Please run at least one trial before "
"calling get_best_parameterization."
)

# Note: Using BestPointMixin directly instead of inheriting to avoid exposing
# unwanted public methods
trial_index, parameterization, model_prediction = none_throws(
BestPointMixin._get_best_trial(
experiment=self._none_throws_experiment(),
generation_strategy=self._generation_strategy_or_choose(),
use_model_predictions=use_model_predictions,
)
)

# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
# None but we do not allow this in the API.
return BestPointMixin._to_best_point_tuple(
experiment=self._none_throws_experiment(),
trial_index=trial_index,
parameterization=parameterization,
model_prediction=model_prediction,
)

def get_pareto_frontier(
self, use_model_predictions: bool = True
) -> dict[int, tuple[TParameterization, TOutcome]]:
) -> list[tuple[TParameterization, TOutcome, int, str]]:
"""
Calculates the in-sample Pareto frontier.
Identifies the parameterizations which are predicted to efficiently trade-off
between all objectives in a multi-objective optimization, also called the
in-sample Pareto frontier.
Returns:
A mapping of trial index to its parameterization and metric values.
"""
...
A list of tuples containing:
- The parameters predicted to have the best optimization value without
violating any outcome constraints.
- The metric values for the best parameterization. Uses model
prediction if use_model_predictions=True, otherwise returns
observed data.
- The trial which most recently ran the best parameterization
- The name of the best arm (each trial has a unique name associated
with each parameterization).
"""

if len(self._none_throws_experiment().trials) < 1:
raise UnsupportedError(
"No trials have been run yet. Please run at least one trial before "
"calling get_pareto_frontier."
)

frontier = BestPointMixin._get_pareto_optimal_parameters(
experiment=self._none_throws_experiment(),
# Requiring true GenerationStrategy here, ideally we will loosen this
# in the future
generation_strategy=assert_is_instance(
self._generation_strategy_or_choose(), GenerationStrategy
),
use_model_predictions=use_model_predictions,
)

# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
# None but we do not allow this in the API.
return [
BestPointMixin._to_best_point_tuple(
experiment=self._none_throws_experiment(),
trial_index=trial_index,
parameterization=parameterization,
model_prediction=model_prediction,
)
for trial_index, (parameterization, model_prediction) in frontier.items()
]

def predict(
self,
Expand Down Expand Up @@ -713,7 +788,7 @@ def predict(
for parameters in points
]
)
except (UserInputError, AssertionError) as e:
except (NotImplementedError, AssertionError) as e:
raise UnsupportedError(
"Predicting with the GenerationStrategy's modelbridge failed. This "
"could be because the current GenerationNode is not predictive -- try "
Expand Down
164 changes: 163 additions & 1 deletion ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,168 @@ def test_compute_analyses(self) -> None:
self.assertEqual(len(cards), 1)
self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")

@mock_botorch_optimize
def test_get_best_parameterization(self) -> None:
client = Client()

client.configure_experiment(
experiment_config=ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo")
# Set initialization_budget=3 so we can reach a predictive GenerationNode
# quickly
client.configure_generation_strategy(
generation_strategy_config=GenerationStrategyConfig(initialization_budget=3)
)

with self.assertRaisesRegex(UnsupportedError, "No trials have been run yet"):
client.get_best_parameterization()

for _ in range(3):
for index, parameters in client.get_next_trials(maximum_trials=1).items():
client.complete_trial(
trial_index=index,
raw_data={
"foo": assert_is_instance(parameters["x1"], float) ** 2,
},
)

parameters, prediction, index, name = client.get_best_parameterization()
self.assertIn(
name,
[
none_throws(assert_is_instance(trial, Trial).arm).name
for trial in client._none_throws_experiment().trials.values()
],
)
self.assertTrue(
client._none_throws_experiment().search_space.check_membership(
parameterization=parameters # pyre-ignore[6]
)
)
self.assertEqual({*prediction.keys()}, {"foo"})

# Run a non-Sobol trial
for index, parameters in client.get_next_trials(maximum_trials=1).items():
client.complete_trial(
trial_index=index,
raw_data={
"foo": assert_is_instance(parameters["x1"], float) ** 2,
},
)
parameters, prediction, index, name = client.get_best_parameterization()
self.assertIn(
name,
[
none_throws(assert_is_instance(trial, Trial).arm).name
for trial in client._none_throws_experiment().trials.values()
],
)
self.assertTrue(
client._none_throws_experiment().search_space.check_membership(
parameterization=parameters # pyre-fixme[6]
)
)
self.assertEqual({*prediction.keys()}, {"foo"})

@mock_botorch_optimize
def test_get_pareto_frontier(self) -> None:
client = Client()

client.configure_experiment(
experiment_config=ExperimentConfig(
parameters=[
RangeParameterConfig(
name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
),
],
name="foo",
)
)
client.configure_optimization(objective="foo, bar")
# Set initialization_budget=3 so we can reach a predictive GenerationNode
# quickly
client.configure_generation_strategy(
generation_strategy_config=GenerationStrategyConfig(initialization_budget=3)
)

with self.assertRaisesRegex(UnsupportedError, "No trials have been run yet"):
client.get_pareto_frontier()

for _ in range(3):
for index, parameters in client.get_next_trials(maximum_trials=1).items():
client.complete_trial(
trial_index=index,
raw_data={
"foo": assert_is_instance(parameters["x1"], float) ** 2,
"bar": 0.0,
},
)

frontier = client.get_pareto_frontier(False)
for parameters, prediction, index, name in frontier:
self.assertEqual(
none_throws(
assert_is_instance(
client._none_throws_experiment().trials[index], Trial
).arm
).name,
name,
)
self.assertIn(
name,
[
none_throws(assert_is_instance(trial, Trial).arm).name
for trial in client._none_throws_experiment().trials.values()
],
)
self.assertTrue(
client._none_throws_experiment().search_space.check_membership(
parameterization=parameters # pyre-ignore[6]
)
)
self.assertEqual({*prediction.keys()}, {"foo", "bar"})

# Run a non-Sobol trial
for index, parameters in client.get_next_trials(maximum_trials=1).items():
client.complete_trial(
trial_index=index,
raw_data={
"foo": assert_is_instance(parameters["x1"], float) ** 2,
"bar": 0.0,
},
)
frontier = client.get_pareto_frontier()
for parameters, prediction, index, name in frontier:
self.assertEqual(
none_throws(
assert_is_instance(
client._none_throws_experiment().trials[index], Trial
).arm
).name,
name,
)
self.assertIn(
name,
[
none_throws(assert_is_instance(trial, Trial).arm).name
for trial in client._none_throws_experiment().trials.values()
],
)
self.assertTrue(
client._none_throws_experiment().search_space.check_membership(
parameterization=parameters # pyre-fixme[6]
)
)
self.assertEqual({*prediction.keys()}, {"foo", "bar"})

@mock_botorch_optimize
def test_predict(self) -> None:
client = Client()
Expand All @@ -919,7 +1081,7 @@ def test_predict(self) -> None:
)
)
client.configure_optimization(objective="foo", outcome_constraints=["bar >= 0"])
# Set num_initialization_trials=3 so we can reach a predictive GenerationNode
# Set initialization_budget=3 so we can reach a predictive GenerationNode
# quickly
client.configure_generation_strategy(
generation_strategy_config=GenerationStrategyConfig(initialization_budget=3)
Expand Down
49 changes: 48 additions & 1 deletion ax/service/utils/best_point_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)
from ax.core.trial import Trial
from ax.core.types import TModelPredictArm, TParameterization
from ax.exceptions.core import UserInputError
from ax.modelbridge.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -48,7 +49,7 @@
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws


logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -709,3 +710,49 @@ def percent_change(x: float, y: float, minimize: bool) -> float:
y=best_obj_value,
minimize=optimization_config.objective.minimize,
)

@staticmethod
def _to_best_point_tuple(
experiment: Experiment,
trial_index: int,
parameterization: TParameterization,
model_prediction: TModelPredictArm | None,
) -> tuple[TParameterization, dict[str, float | tuple[float, float]], int, str]:
"""
Return the tuple expected by the return signature of get_best_parameterization
and get_pareto_frontier in the Ax API.
TODO: Remove this helper when we clean up BestPointMixin.
Returns:
- The parameters predicted to have the best optimization value without
violating any outcome constraints.
- The metric values for the best parameterization. Uses model prediction if
use_model_predictions=True, otherwise returns observed data.
- The trial which most recently ran the best parameterization
- The name of the best arm (each trial has a unique name associated with
each parameterization)
"""

if model_prediction is not None:
mean, covariance = model_prediction

prediction: dict[str, float | tuple[float, float]] = {
metric_name: (
mean[metric_name],
none_throws(covariance)[metric_name][metric_name],
)
for metric_name in mean.keys()
}
else:
data_dict = experiment.lookup_data(trial_indices=[trial_index]).df.to_dict()

prediction: dict[str, float | tuple[float, float]] = {
data_dict["metric_name"][i]: (data_dict["mean"][i], data_dict["sem"][i])
for i in range(len(data_dict["metric_name"]))
}

trial = assert_is_instance(experiment.trials[trial_index], Trial)
arm = none_throws(trial.arm)

return parameterization, prediction, trial_index, arm.name

0 comments on commit 7f9b632

Please sign in to comment.