Skip to content

Commit

Permalink
Introduce BenchmarkTimeVaryingMetric for non-map metrics that vary ov…
Browse files Browse the repository at this point in the history
…er time (facebook#3184)

Summary:

**Context**: There are situations where a metric is not a MapMetric, but its value changes depending on time.

**This diff**:
Introduces `BenchmarkTimeVaryingMetric`, which is not a MapMetric, can consume data with multiple time steps, is available while running, and requires a backend simulator.

Reviewed By: Balandat

Differential Revision: D67224949
  • Loading branch information
esantorella authored and facebook-github-bot committed Dec 18, 2024
1 parent 7058716 commit 7580787
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 24 deletions.
45 changes: 36 additions & 9 deletions ax/benchmark/benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
Metrics vary on two dimensions: Whether they are `MapMetric`s or not, and
whether they are available while running or not.
There are two Metric classes:
- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
There are four Metric classes:
- `BenchmarkMetric`: A non-Map metric
is not available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
data is available while running.
There are further benchmark classes that are not yet implemented:
- `BenchmarkTimeVaryingMetric`: For when outputs should be `Data` and the metric
is available while running.
- `BenchmarkMapUnavailableWhileRunningMetric`: For when outputs should be
Expand Down Expand Up @@ -214,8 +212,10 @@ def _df_to_result(self, df: DataFrame) -> MetricFetchResult:

class BenchmarkMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is not available while
running.
Non-map Metric for benchmarking that is not available while running.
It cannot process data with multiple time steps, as it would only return one
value -- the value it has at completion time -- regardless.
"""

def _class_specific_metdata_validation(
Expand All @@ -234,12 +234,27 @@ def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
return Ok(value=Data(df=df.drop(columns=["step"])))


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
class BenchmarkTimeVaryingMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is available while
running.
Non-Map Metric for benchmarking that is available while running.
It can produce different values at different times depending on when it is
called, using the `time` on a `BackendSimulator`.
"""

@classmethod
def is_available_while_running(cls) -> bool:
return True

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
return Ok(
value=Data(df=df[df["step"] == df["step"].max()].drop(columns=["step"]))
)


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
"""MapMetric for benchmarking. It is available while running."""

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
Expand All @@ -253,3 +268,15 @@ def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
# Just in case the key was renamed by a subclass
df = df.rename(columns={"step": self.map_key_info.key})
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))


class BenchmarkMapUnavailableWhileRunningMetric(MapMetric, BenchmarkMetricBase):
# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
# Just in case the key was renamed by a subclass
df = df.rename(columns={"step": self.map_key_info.key})
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
84 changes: 69 additions & 15 deletions ax/benchmark/tests/test_benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from ax.benchmark.benchmark_metric import (
_get_no_metadata_msg,
BenchmarkMapMetric,
BenchmarkMapUnavailableWhileRunningMetric,
BenchmarkMetric,
BenchmarkMetricBase,
BenchmarkTimeVaryingMetric,
)
from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata
from ax.core.arm import Arm
Expand Down Expand Up @@ -113,10 +116,20 @@ def setUp(self) -> None:
name: BenchmarkMetric(name=name, lower_is_better=True)
for name in self.outcome_names
}
self.tv_metrics = {
name: BenchmarkTimeVaryingMetric(name=name, lower_is_better=True)
for name in self.outcome_names
}
self.map_metrics = {
name: BenchmarkMapMetric(name=name, lower_is_better=True)
for name in self.outcome_names
}
self.map_unavailable_while_running_metrics = {
name: BenchmarkMapUnavailableWhileRunningMetric(
name=name, lower_is_better=True
)
for name in self.outcome_names
}

def test_available_while_running(self) -> None:
self.assertFalse(self.metrics["test_metric1"].is_available_while_running())
Expand All @@ -125,10 +138,24 @@ def test_available_while_running(self) -> None:
self.assertTrue(self.map_metrics["test_metric1"].is_available_while_running())
self.assertTrue(BenchmarkMapMetric.is_available_while_running())

self.assertTrue(self.tv_metrics["test_metric1"].is_available_while_running())
self.assertTrue(BenchmarkTimeVaryingMetric.is_available_while_running())

self.assertFalse(
self.map_unavailable_while_running_metrics[
"test_metric1"
].is_available_while_running()
)
self.assertFalse(
BenchmarkMapUnavailableWhileRunningMetric.is_available_while_running()
)

def test_exceptions(self) -> None:
for metric in [
self.metrics["test_metric1"],
self.map_metrics["test_metric1"],
self.tv_metrics["test_metric1"],
self.map_unavailable_while_running_metrics["test_metric1"],
]:
with self.subTest(
f"No-metadata error, metric class={metric.__class__.__name__}"
Expand Down Expand Up @@ -164,7 +191,7 @@ def test_exceptions(self) -> None:
self.metrics["test_metric1"].fetch_trial_data(trial=trial)

def _test_fetch_trial_data_one_time_step(
self, batch: bool, metric: BenchmarkMetric | BenchmarkMapMetric
self, batch: bool, metric: BenchmarkMetricBase
) -> None:
trial = get_test_trial(batch=batch, has_simulator=True)
df1 = assert_is_instance(metric.fetch_trial_data(trial=trial).value, Data).df
Expand All @@ -179,7 +206,12 @@ def _test_fetch_trial_data_one_time_step(
def test_fetch_trial_data_one_time_step(self) -> None:
for batch, metrics in product(
[False, True],
[self.metrics, self.map_metrics],
[
self.metrics,
self.map_metrics,
self.tv_metrics,
self.map_unavailable_while_running_metrics,
],
):
metric = metrics["test_metric1"]
with self.subTest(
Expand All @@ -195,37 +227,59 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> N
- Has simulator, metric is 'BenchmarkMapMetric' -> df grows with each step
- No simulator, metric is 'BenchmarkMapMetric' -> all data present while
running (but realistically it would never be RUNNING)
- Has simulator, metric is 'BenchmarkTimeVaryingMetric' -> one step at
a time, evolving as we take steps
- No simulator, metric is 'BenchmarkTimeVaryingMetric' -> completes
immediately and returns last step
See table in benchmark_metric.py for more details.
"""
metric_name = "test_metric1"

# Iterating over list of length 1 here because there will be more
# metrics in the next diff
for metric in [self.map_metrics[metric_name]]:
trial_with_simulator = get_test_trial(
for metric, has_simulator in product(
[
self.map_metrics[metric_name],
self.tv_metrics[metric_name],
self.map_unavailable_while_running_metrics[metric_name],
],
[False, True],
):
trial = get_test_trial(
has_metadata=True,
batch=batch,
multiple_time_steps=True,
has_simulator=True,
has_simulator=has_simulator,
)
data = metric.fetch_trial_data(trial=trial_with_simulator).value
data = metric.fetch_trial_data(trial=trial).value
df_or_map_df = data.map_df if isinstance(metric, MapMetric) else data.df
self.assertEqual(len(df_or_map_df), len(trial_with_simulator.arms))
returns_full_data = (not has_simulator) and isinstance(metric, MapMetric)
self.assertEqual(
len(df_or_map_df), len(trial.arms) * (3 if returns_full_data else 1)
)
drop_cols = ["virtual runtime"]
if not isinstance(metric, MapMetric):
drop_cols += ["step"]

expected_df = _get_one_step_df(
batch=batch, metric_name=metric_name, step=0
).drop(columns=drop_cols)
self.assertEqual(df_or_map_df.to_dict(), expected_df.to_dict())
if returns_full_data:
self.assertEqual(
df_or_map_df[df_or_map_df["step"] == 0].to_dict(),
expected_df.to_dict(),
)
else:
self.assertEqual(df_or_map_df.to_dict(), expected_df.to_dict())

backend_simulator = trial_with_simulator.run_metadata[
backend_simulator = trial.run_metadata[
"benchmark_metadata"
].backend_simulator
self.assertEqual(backend_simulator is None, not has_simulator)
if backend_simulator is None:
continue
self.assertEqual(backend_simulator.time, 0)
sim_trial = none_throws(
backend_simulator.get_sim_trial_by_index(trial_with_simulator.index)
backend_simulator.get_sim_trial_by_index(trial.index)
)
self.assertIn(sim_trial, backend_simulator._running)
backend_simulator.update()
Expand All @@ -234,13 +288,13 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> N
backend_simulator.update()
self.assertIn(sim_trial, backend_simulator._completed)
self.assertEqual(backend_simulator.time, 2)
data = metric.fetch_trial_data(trial=trial_with_simulator).value
data = metric.fetch_trial_data(trial=trial).value
if isinstance(metric, MapMetric):
map_df = data.map_df
self.assertEqual(len(map_df), 2 * len(trial_with_simulator.arms))
self.assertEqual(len(map_df), 2 * len(trial.arms))
self.assertEqual(set(map_df["step"].tolist()), {0, 1})
df = data.df
self.assertEqual(len(df), len(trial_with_simulator.arms))
self.assertEqual(len(df), len(trial.arms))
expected_df = _get_one_step_df(
batch=batch, metric_name=metric_name, step=1
).drop(columns=drop_cols)
Expand Down

0 comments on commit 7580787

Please sign in to comment.