Implement get_best_arm, get_pareto_frontier (facebook#3142)
Summary:

Implement "best point" functionality in Client. A few notes:
* Renamed get_best_trial --> get_best_arm: this more accurately reflects what the method actually does (calculate the best in-sample point), and the rename will let us keep using the method when we create BatchClient
* If the GenerationStrategy is not on a predictive step, we return {} for the prediction term in both get_best_arm and get_pareto_frontier and log an error (see the usage sketch below)

Differential Revision: D66702545
mpolson64 authored and facebook-github-bot committed Dec 9, 2024
1 parent 3f4d07e commit c3cdd3b
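As a usage sketch of the new surface (hypothetical, not part of this commit; the setup is condensed from test_get_best_arm below, and the configs import path is an assumption, not confirmed by this diff):

# Hypothetical usage sketch; setup condensed from test_get_best_arm below.
# The ax.preview.api.configs import path is assumed.
from ax.preview.api.client import Client
from ax.preview.api.configs import (
    ExperimentConfig,
    GenerationStrategyConfig,
    ParameterType,
    RangeParameterConfig,
)

client = Client()
client.configure_experiment(
    experiment_config=ExperimentConfig(
        parameters=[
            RangeParameterConfig(
                name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
            ),
        ],
        name="foo",
    )
)
client.configure_optimization(objective="foo")
client.configure_generation_strategy(
    generation_strategy_config=GenerationStrategyConfig(num_initialization_trials=3)
)

# Run and complete a few Sobol trials.
for _ in range(3):
    for index, parameters in client.get_next_trials(maximum_trials=1).items():
        client.complete_trial(
            trial_index=index, raw_data={"foo": parameters["x1"] ** 2}
        )

# Formerly get_best_trial; now returns the best arm's name (str) rather than
# a trial index (int).
name, parameters, prediction = client.get_best_arm()

# While the generation strategy is still on a non-predictive (Sobol) node,
# prediction is {} and an error is logged instead of raising.
if not prediction:
    print(f"No model predictions available for arm {name}")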
Showing 2 changed files with 304 additions and 11 deletions.
162 changes: 151 additions & 11 deletions ax/preview/api/client.py
@@ -45,6 +45,7 @@
    optimization_config_from_string,
)
from ax.service.scheduler import Scheduler, SchedulerOptions
+from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.logger import get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance, none_throws
@@ -667,29 +668,168 @@ def compute_analyses(

return cards

-    def get_best_trial(
+    def get_best_arm(
        self, use_model_predictions: bool = True
-    ) -> tuple[int, TParameterization, TOutcome]:
+    ) -> tuple[str, TParameterization, TOutcome]:
        """
-        Calculates the best in-sample trial.
+        Identifies the best parameterization tried in the experiment so far, the best
+        in-sample arm.
+        If `use_model_predictions` is True, first attempts to do so with the model used
+        in optimization and its corresponding predictions if available. If
+        `use_model_predictions` is False or attempts to use the model fails, falls back
+        to the best raw objective based on the data fetched from the experiment.
+        Arms which violate outcome constraints are not eligible to be the best arm.
        Returns:
-            - The index of the best trial
-            - The parameters of the best trial
-            - The metric values associated with the best trial
+            - The name of the best arm
+            - The parameters of the best arm
+            - The metric values for the best arm. Uses model prediction if
+                use_model_predictions=True, otherwise returns observed data.
        """
-        ...

+        if len(self._none_throws_experiment().trials) < 1:
+            raise UnsupportedError(
+                "No trials have been run yet. Please run at least one trial before "
+                "calling get_best_arm."
+            )
+
+        # Note: Using BestPointMixin directly instead of inheriting to avoid exposing
+        # unwanted public methods
+        trial_index, parameters, _ = none_throws(
+            BestPointMixin._get_best_trial(
+                experiment=self._none_throws_experiment(),
+                # Requiring true GenerationStrategy here, ideally we will loosen this
+                # in the future
+                generation_strategy=assert_is_instance(
+                    self._generation_strategy_or_choose(), GenerationStrategy
+                ),
+                use_model_predictions=use_model_predictions,
+            )
+        )
+
+        arm = none_throws(
+            assert_is_instance(
+                self._none_throws_experiment().trials[trial_index], Trial
+            ).arm
+        )
+
+        if use_model_predictions:
+            try:
+                # pyre-fixme[6]: Core Ax allows users to specify TParameterization
+                # values as None but we do not allow this in the API.
+                prediction = self.predict(points=[parameters])[0]
+            except UnsupportedError:
+                logger.error(
+                    "Model predictions are not available, returning empty prediction"
+                )
+
+                prediction = {}
+        else:
+            data_dict = (
+                self._none_throws_experiment()
+                .lookup_data(trial_indices=[trial_index])
+                .df.to_dict()
+            )
+
+            prediction = {
+                data_dict["metric_name"][i]: (data_dict["mean"][i], data_dict["sem"][i])
+                for i in range(len(data_dict["metric_name"]))
+            }
+
+        # pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
+        # None but we do not allow this in the API.
+        return arm.name, parameters, prediction

    def get_pareto_frontier(
        self, use_model_predictions: bool = True
-    ) -> dict[int, tuple[TParameterization, TOutcome]]:
+    ) -> dict[str, tuple[TParameterization, TOutcome]]:
        """
        Calculates the in-sample Pareto frontier.
        Returns:
-            A mapping of trial index to its parameterization and metric values.
+            A mapping of the arm name to the parameterization and predicted or observed
+            outcome.
        """
-        ...

+        if len(self._none_throws_experiment().trials) < 1:
+            raise UnsupportedError(
+                "No trials have been run yet. Please run at least one trial before "
+                "calling get_pareto_frontier."
+            )
+
+        frontier = BestPointMixin._get_pareto_optimal_parameters(
+            experiment=self._none_throws_experiment(),
+            # Requiring true GenerationStrategy here, ideally we will loosen this
+            # in the future
+            generation_strategy=assert_is_instance(
+                self._generation_strategy_or_choose(), GenerationStrategy
+            ),
+            use_model_predictions=use_model_predictions,
+        )
+
+        frontier_list = [*frontier.items()]
+
+        arm_names = [
+            none_throws(
+                assert_is_instance(
+                    self._none_throws_experiment().trials[trial_index], Trial
+                ).arm
+            ).name
+            for trial_index, _ in frontier_list
+        ]
+
+        if use_model_predictions:
+            try:
+                predictions = self.predict(
+                    # pyre-fixme[6]: Core Ax allows users to specify TParameterization
+                    # values as None but we do not allow this in the API.
+                    points=[value[0] for _key, value in frontier_list]
+                )
+            except UnsupportedError:
+                logger.error(
+                    "Model predictions are not available, returning empty prediction"
+                )
+
+                predictions: list[TOutcome] = [{} for _ in frontier]
+        else:
+            predictions = []
+            for trial_index in frontier.keys():
+                data_dict = (
+                    self._none_throws_experiment()
+                    .lookup_data(trial_indices=[trial_index])
+                    .df.to_dict()
+                )
+
+                predictions.append(
+                    {
+                        data_dict["metric_name"][i]: (
+                            data_dict["mean"][i],
+                            data_dict["sem"][i],
+                        )
+                        for i in range(len(data_dict["metric_name"]))
+                    }
+                )

+        # pyre-fixme[7]: Core Ax allows users to specify TParameterization
+        # values as None but we do not allow this in the API.
+        return {
+            arm_names[i]: (frontier_list[i][1][0], predictions[i])
+            for i in range(len(frontier_list))
+        }

def predict(
self,
@@ -723,7 +863,7 @@ def predict
for parameters in points
]
)
-        except (UserInputError, AssertionError) as e:
+        except (NotImplementedError, AssertionError) as e:
raise UnsupportedError(
"Predicting with the GenerationStrategy's modelbridge failed. This "
"could be because the current GenerationNode is not predictive -- try "
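And a short sketch of consuming get_pareto_frontier's new return shape (hypothetical, not part of the commit; `client` is assumed to be configured with objectives "foo, bar" and completed trials, as in test_get_pareto_frontier below):

# `client` is assumed to be a configured two-objective Client with
# completed trials (see test_get_pareto_frontier below).
frontier = client.get_pareto_frontier(use_model_predictions=False)
for arm_name, (parameters, outcome) in frontier.items():
    # Each outcome maps metric name -> (mean, sem), matching the dict
    # built from lookup_data in get_pareto_frontier above.
    for metric_name, (mean, sem) in outcome.items():
        print(f"{arm_name}: {metric_name} = {mean} +/- {sem}")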
153 changes: 153 additions & 0 deletions ax/preview/api/tests/test_client.py
@@ -911,6 +911,159 @@ def test_compute_analyses(self) -> None:
self.assertEqual(len(cards), 1)
self.assertEqual(cards[0].name, "ParallelCoordinatesPlot")

+    def test_get_best_arm(self) -> None:
+        client = Client()
+
+        client.configure_experiment(
+            experiment_config=ExperimentConfig(
+                parameters=[
+                    RangeParameterConfig(
+                        name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
+                    ),
+                ],
+                name="foo",
+            )
+        )
+        client.configure_optimization(objective="foo")
+        # Set num_initialization_trials=3 so we can reach a predictive GenerationNode
+        # quickly
+        client.configure_generation_strategy(
+            generation_strategy_config=GenerationStrategyConfig(
+                num_initialization_trials=3
+            )
+        )
+
+        with self.assertRaisesRegex(UnsupportedError, "No trials have been run yet"):
+            client.get_best_arm()
+
+        for _ in range(3):
+            for index, parameters in client.get_next_trials(maximum_trials=1).items():
+                client.complete_trial(
+                    trial_index=index,
+                    raw_data={
+                        "foo": assert_is_instance(parameters["x1"], float) ** 2,
+                    },
+                )
+
+        name, parameters, prediction = client.get_best_arm()
+        self.assertIn(
+            name,
+            [
+                none_throws(assert_is_instance(trial, Trial).arm).name
+                for trial in client._none_throws_experiment().trials.values()
+            ],
+        )
+        self.assertTrue(
+            client._none_throws_experiment().search_space.check_membership(
+                parameterization=parameters  # pyre-ignore[6]
+            )
+        )
+        self.assertEqual(prediction, {})  # No prediction since we are still in Sobol
+
+        # Run a non-Sobol trial
+        for index, parameters in client.get_next_trials(maximum_trials=1).items():
+            client.complete_trial(
+                trial_index=index,
+                raw_data={
+                    "foo": assert_is_instance(parameters["x1"], float) ** 2,
+                },
+            )
+        name, parameters, prediction = client.get_best_arm()
+        self.assertIn(
+            name,
+            [
+                none_throws(assert_is_instance(trial, Trial).arm).name
+                for trial in client._none_throws_experiment().trials.values()
+            ],
+        )
+        self.assertTrue(
+            client._none_throws_experiment().search_space.check_membership(
+                parameterization=parameters  # pyre-ignore[6]
+            )
+        )
+        self.assertEqual({*prediction.keys()}, {"foo"})
+
+    def test_get_pareto_frontier(self) -> None:
+        client = Client()
+
+        client.configure_experiment(
+            experiment_config=ExperimentConfig(
+                parameters=[
+                    RangeParameterConfig(
+                        name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1)
+                    ),
+                ],
+                name="foo",
+            )
+        )
+        client.configure_optimization(objective="foo, bar")
+        # Set num_initialization_trials=3 so we can reach a predictive GenerationNode
+        # quickly
+        client.configure_generation_strategy(
+            generation_strategy_config=GenerationStrategyConfig(
+                num_initialization_trials=3
+            )
+        )
+
+        with self.assertRaisesRegex(UnsupportedError, "No trials have been run yet"):
+            client.get_pareto_frontier()
+
+        for _ in range(3):
+            for index, parameters in client.get_next_trials(maximum_trials=1).items():
+                client.complete_trial(
+                    trial_index=index,
+                    raw_data={
+                        "foo": assert_is_instance(parameters["x1"], float) ** 2,
+                        "bar": 0.0,
+                    },
+                )
+
+        frontier = client.get_pareto_frontier(False)
+        for name, point in frontier.items():
+            parameters, prediction = point
+
+            self.assertIn(
+                name,
+                [
+                    none_throws(assert_is_instance(trial, Trial).arm).name
+                    for trial in client._none_throws_experiment().trials.values()
+                ],
+            )
+            self.assertTrue(
+                client._none_throws_experiment().search_space.check_membership(
+                    parameterization=parameters  # pyre-ignore[6]
+                )
+            )
+            self.assertEqual(
+                prediction, {}
+            )  # No prediction since we are still in Sobol
+
+        # Run a non-Sobol trial
+        for index, parameters in client.get_next_trials(maximum_trials=1).items():
+            client.complete_trial(
+                trial_index=index,
+                raw_data={
+                    "foo": assert_is_instance(parameters["x1"], float) ** 2,
+                    "bar": 0.0,
+                },
+            )
+        frontier = client.get_pareto_frontier()
+        for name, point in frontier.items():
+            parameters, prediction = point
+            self.assertIn(
+                name,
+                [
+                    none_throws(assert_is_instance(trial, Trial).arm).name
+                    for trial in client._none_throws_experiment().trials.values()
+                ],
+            )
+            self.assertTrue(
+                client._none_throws_experiment().search_space.check_membership(
+                    parameterization=parameters  # pyre-ignore[6]
+                )
+            )
+            self.assertEqual({*prediction.keys()}, {"foo", "bar"})
+
@mock_botorch_optimize
def test_predict(self) -> None:
client = Client()
