Skip to content

Commit

Permalink
update BestPointMixin to support BatchTrial in benchmarks (facebook#2014)
Browse files Browse the repository at this point in the history

Summary:

see title

Reviewed By: esantorella

Differential Revision: D51534384
  • Loading branch information
sdaulton authored and facebook-github-bot committed Nov 29, 2023
1 parent 14a5615 commit fc20d4f
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 40 deletions.
50 changes: 50 additions & 0 deletions ax/service/tests/test_best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
import pandas as pd

from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.core.trial import Trial
from ax.exceptions.core import DataRequiredError
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.core_stubs import (
get_arm_weights2,
get_arms_from_dict,
get_experiment_with_batch_trial,
get_experiment_with_observations,
get_experiment_with_trial,
)
Expand Down Expand Up @@ -95,6 +100,51 @@ def test_get_trace(self) -> None:
exp = get_experiment_with_trial()
self.assertEqual(get_trace(exp), [])

# test batch trial
exp = get_experiment_with_batch_trial()
trial = exp.trials[0]
exp.optimization_config.outcome_constraints[0].relative = False
trial.mark_running(no_runner_required=True).mark_completed()
df_dict = []
for i, arm in enumerate(trial.arms):
df_dict.extend(
[
{
"trial_index": 0,
"metric_name": m,
"arm_name": arm.name,
"mean": float(i),
"sem": 0.0,
}
for m in exp.metrics.keys()
]
)
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict)))
self.assertEqual(get_trace(exp), [len(trial.arms) - 1])
# test that there is a performance metric in the trace for each
# completed/early-stopped trial
trial1 = checked_cast(BatchTrial, trial).clone_to()
trial1.mark_abandoned()
arms = get_arms_from_dict(get_arm_weights2())
trial2 = exp.new_batch_trial(GeneratorRun(arms))
trial2.mark_running(no_runner_required=True).mark_completed()
df_dict2 = []
for i, arm in enumerate(trial2.arms):
df_dict2.extend(
[
{
"trial_index": 2,
"metric_name": m,
"arm_name": arm.name,
"mean": 10 * float(i),
"sem": 0.0,
}
for m in exp.metrics.keys()
]
)
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
self.assertEqual(get_trace(exp), [2, 20.0])

def test_get_hypervolume(self) -> None:
# W/ empty data.
exp = get_experiment_with_trial()
Expand Down
57 changes: 39 additions & 18 deletions ax/service/tests/test_best_point_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def test_extract_Y_from_data(self) -> None:
"sem": 0.0,
}
)
df_0 = df_dicts[:2]
experiment.attach_data(Data(df=pd.DataFrame.from_records(df_dicts)))

expected_Y = torch.stack(
Expand All @@ -372,39 +371,47 @@ def test_extract_Y_from_data(self) -> None:
],
dim=-1,
)
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
expected_trial_indices = torch.arange(20)
self.assertTrue(torch.allclose(Y, expected_Y))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Check that it respects ordering of metric names.
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["bar", "foo"],
)
self.assertTrue(torch.allclose(Y, expected_Y[:, [1, 0]]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Extract partial metrics.
Y = extract_Y_from_data(experiment=experiment, metric_names=["bar"])
Y, trial_indices = extract_Y_from_data(
experiment=experiment, metric_names=["bar"]
)
self.assertTrue(torch.allclose(Y, expected_Y[:, [1]]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))
# Works with messed up ordering of data.
clone_dicts = df_dicts.copy()
random.shuffle(clone_dicts)
experiment._data_by_trial = {}
experiment.attach_data(Data(df=pd.DataFrame.from_records(clone_dicts)))
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
self.assertTrue(torch.allclose(Y, expected_Y))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices))

# Check that it skips trials that are not completed.
experiment.trials[0].mark_running(no_runner_required=True, unsafe=True)
experiment.trials[1].mark_abandoned(unsafe=True)
Y = extract_Y_from_data(
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
)
self.assertTrue(torch.allclose(Y, expected_Y[2:]))
self.assertTrue(torch.equal(trial_indices, expected_trial_indices[2:]))

# Error with missing data.
with self.assertRaisesRegex(
Expand All @@ -420,28 +427,42 @@ def test_extract_Y_from_data(self) -> None:

# Error with extra data.
with self.assertRaisesRegex(
UserInputError, "Trial data has more than one row per metric. "
UserInputError, "Trial data has more than one row per arm, metric pair. "
):
# Skipping first 5 data points since first two trials are not completed.
base_df = pd.DataFrame.from_records(df_dicts[5:])

extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
data=Data(df=pd.concat((base_df, base_df))),
)

# Check that it errors with BatchTrial.
# Check that it works with BatchTrial.
experiment = get_branin_experiment()
BatchTrial(experiment=experiment, index=0).mark_running(
no_runner_required=True
).mark_completed()
with self.assertRaisesRegex(UnsupportedError, "BatchTrials are not supported."):
extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
data=Data(df=pd.DataFrame.from_records(df_0)),
)
batch_trial = BatchTrial(experiment=experiment, index=0)
batch_trial.add_arm(Arm(name="0_0", parameters={"x1": 0.0, "x2": 0.0}))
batch_trial.add_arm(Arm(name="0_1", parameters={"x1": 1.0, "x2": 0.0}))
batch_trial.mark_running(no_runner_required=True).mark_completed()
df_dicts_batch = []
for i in (0, 1):
for metric_name in ["foo", "bar"]:
df_dicts_batch.append(
{
"trial_index": 0,
"metric_name": metric_name,
"arm_name": f"0_{i}",
"mean": float(i) if metric_name == "foo" else i + 5.0,
"sem": 0.0,
}
)
batch_df = pd.DataFrame.from_records(df_dicts_batch)
Y, trial_indices = extract_Y_from_data(
experiment=experiment,
metric_names=["foo", "bar"],
data=Data(df=batch_df),
)
self.assertTrue(torch.allclose(Y, expected_Y[:2]))
self.assertTrue(torch.equal(trial_indices, torch.zeros(2, dtype=torch.long)))

def test_is_row_feasible(self) -> None:
exp = get_experiment_with_observations(
Expand Down
28 changes: 17 additions & 11 deletions ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def extract_Y_from_data(
experiment: Experiment,
metric_names: List[str],
data: Optional[Data] = None,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
r"""Converts the experiment observation data into a tensor.
NOTE: This requires block design for observations. It will
Expand All @@ -796,11 +796,14 @@ def extract_Y_from_data(
each `trial_index` in the `data`.
Returns:
A tensor of observed metrics.
A two-element Tuple containing a tensor of observed metrics and a
tensor of trial_indices.
"""
df = data.df if data is not None else experiment.lookup_data().df
if len(df) == 0:
return torch.empty(0, len(metric_names), dtype=torch.double)
y = torch.empty(0, len(metric_names), dtype=torch.double)
indices = torch.empty(0, dtype=torch.long)
return y, indices

trials_to_use = []
data_to_use = df[df["metric_name"].isin(metric_names)]
Expand All @@ -810,12 +813,10 @@ def extract_Y_from_data(
if trial.status not in [TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED]:
# Skip trials that are not completed or early stopped.
continue
if isinstance(trial, BatchTrial):
raise UnsupportedError("BatchTrials are not supported.")
trials_to_use.append(trial_idx)
if len(trial_data) > len(set(trial_data["metric_name"])):
if trial_data[["metric_name", "arm_name"]].duplicated().any():
raise UserInputError(
"Trial data has more than one row per metric. "
"Trial data has more than one row per arm, metric pair. "
f"Got\n\n{trial_data}\n\nfor trial {trial_idx}."
)
# We have already ensured that `trial_data` has no metrics not in
Expand All @@ -830,13 +831,18 @@ def extract_Y_from_data(
keeps = df["trial_index"].isin(trials_to_use)

if not keeps.any():
return torch.empty(0, len(metric_names), dtype=torch.double)
return torch.empty(0, len(metric_names), dtype=torch.double), torch.empty(
0, dtype=torch.long
)

data_as_wide = df[keeps].pivot(
columns="metric_name", index="trial_index", values="mean"
columns="metric_name", index=["trial_index", "arm_name"], values="mean"
)[metric_names]

return torch.tensor(data_as_wide.to_numpy()).to(torch.double)
means = torch.tensor(data_as_wide.to_numpy()).to(torch.double)
trial_indices = torch.tensor(
data_as_wide.reset_index()["trial_index"].to_numpy(), dtype=torch.long
)
return means, trial_indices


def _objective_threshold_from_nadir(
Expand Down
56 changes: 45 additions & 11 deletions ax/service/utils/best_point_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,24 @@ def _get_trace(
experiment: Experiment,
optimization_config: Optional[OptimizationConfig] = None,
) -> List[float]:
"""Compute the optimization trace at each iteration.
Given an experiment and an optimization config, compute the performance
at each iteration. For multi-objective, the performance is computed as the
hypervolume. For single objective, the performance is computed as the best
observed objective value.
An iteration here refers to a completed or early-stopped (batch) trial.
There will be one performance metric in the trace for each iteration.
Args:
experiment: The experiment to get the trace for.
optimization_config: Optimization config to use in place of the one
stored on the experiment.
Returns:
A list of performance values at each iteration.
"""
optimization_config = optimization_config or not_none(
experiment.optimization_config
)
Expand All @@ -437,7 +455,9 @@ def _get_trace(
metric_names.update({cons.metric.name})
metric_names = list(metric_names)
# Convert data into a tensor.
Y = extract_Y_from_data(experiment=experiment, metric_names=metric_names)
Y, trial_indices = extract_Y_from_data(
experiment=experiment, metric_names=metric_names
)
if Y.numel() == 0:
return []

Expand Down Expand Up @@ -508,26 +528,40 @@ def _get_trace(
feas = torch.all(torch.stack([c(Y) <= 0 for c in cons_tfs], dim=-1), dim=-1)
# Set the infeasible points to reference point or the worst observed value.
Y_obj[~feas] = infeas_value
# Get unique trial indices. Note: only completed/early-stopped
# trials are present.
unique_trial_indices = trial_indices.unique().sort().values.tolist()
# compute the performance at each iteration (completed/early-stopped
# trial).
# For `BatchTrial`s, there is one performance value per iteration, even
# if the iteration (`BatchTrial`) has multiple arms.
if optimization_config.is_moo_problem:
# Compute the hypervolume trace.
partitioning = DominatedPartitioning(
ref_point=weighted_objective_thresholds.double()
)
# compute hv at each iteration
# compute hv for each iteration (trial_index)
hvs = []
for Yi in Y_obj.split(1):
for trial_index in unique_trial_indices:
new_Y = Y_obj[trial_indices == trial_index]
# update with new point
partitioning.update(Y=Yi)
partitioning.update(Y=new_Y)
hv = partitioning.compute_hypervolume().item()
hvs.append(hv)
return hvs
else:
# Find the best observed value.
raw_maximum = np.maximum.accumulate(Y_obj.cpu().numpy())
if optimization_config.objective.minimize:
# Negate the result if it is a minimization problem.
raw_maximum = -raw_maximum
return raw_maximum.tolist()
running_max = float("-inf")
raw_maximum = np.zeros(len(unique_trial_indices))
# Find the best observed value for each iteration.
# Enumerate the unique trial indices because only indices
# of completed/early-stopped trials are present.
for i, trial_index in enumerate(unique_trial_indices):
new_Y = Y_obj[trial_indices == trial_index]
running_max = max(running_max, new_Y.max().item())
raw_maximum[i] = running_max
if optimization_config.objective.minimize:
# Negate the result if it is a minimization problem.
raw_maximum = -raw_maximum
return raw_maximum.tolist()

@staticmethod
def _get_trace_by_progression(
Expand Down

0 comments on commit fc20d4f

Please sign in to comment.