Introduce BenchmarkTimeVaryingMetric for non-map metrics that vary over time

Summary:
**Context**: There are situations where a metric is not a MapMetric, but its value changes depending on time.

**This diff**:
Introduces `BenchmarkTimeVaryingMetric`, which is not a MapMetric, can consume data with multiple time steps, is available while running, and requires a backend simulator.
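As a usage sketch (illustration only, not part of the diff below; building a trial whose metadata carries a `BackendSimulator` is elided):

from ax.benchmark.benchmark_metric import BenchmarkTimeVaryingMetric

# Constructed like the existing `BenchmarkMetric`.
metric = BenchmarkTimeVaryingMetric(name="test_metric1", lower_is_better=True)

# Unlike `BenchmarkMetric`, it reports data as available while the trial runs.
assert BenchmarkTimeVaryingMetric.is_available_while_running()

# Fetching from a trial whose metadata lacks a `backend_simulator` raises a
# ValueError; with a simulator attached, the value returned depends on the
# simulator's current `time`.
# metric.fetch_trial_data(trial=trial)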

Differential Revision: D67224949
esantorella authored and facebook-github-bot committed Dec 16, 2024
1 parent c2c9522 commit 0079197
Showing 2 changed files with 70 additions and 9 deletions.
47 changes: 42 additions & 5 deletions ax/benchmark/benchmark_metric.py
@@ -8,9 +8,10 @@
"""
Metric classes for Ax benchmarking.
There are two Metric classes:
There are three Metric classes:
- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
is not available while running.
- `BenchmarkTimeVaryingMetric`: For when outputs should be `Data` (not `MapData`) and data
is available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
data is available while running.
"""
@@ -148,8 +149,10 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult

class BenchmarkMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is not available while
running.
Non-map Metric for benchmarking that is not available while running.
It cannot process data with multiple time steps, since it would return only one
value -- the value at completion time -- regardless of the time step.
"""

def _class_specific_metadata_validation(
@@ -164,10 +167,33 @@ def _class_specific_metadata_validation(
)


class BenchmarkTimeVaryingMetric(BenchmarkMetricBase):
"""
Non-map Metric for benchmarking that is available while running.
It can produce different values at different times depending on when it is
called, based on the `time` of a `BackendSimulator`. It cannot process trial
metadata that does not include a `BackendSimulator`.
"""

def _class_specific_metadata_validation(
self, metadata: BenchmarkTrialMetadata, trial_index: int
) -> None:
if metadata.backend_simulator is None:
raise ValueError(
f"Trial {trial_index} has no `backend_simulator`. This is not "
f"supported by `{self.__class__.__name__}` because it is "
"available while running; use `BenchmarkMetric`."
)

@classmethod
def is_available_while_running(cls) -> bool:
return True


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is available while
running.
MapMetric for benchmarking. It is available while running.
"""

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
@@ -178,3 +204,14 @@ class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
@classmethod
def is_available_while_running(cls) -> bool:
return True


class BenchmarkMapNotAvailableWhileRunningMetric(MapMetric, BenchmarkMetricBase):
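"""
MapMetric for benchmarking that is not available while running.
"""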
# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

@classmethod
def is_available_while_running(cls) -> bool:
return False
32 changes: 28 additions & 4 deletions ax/benchmark/tests/test_benchmark_metric.py
@@ -14,6 +14,7 @@
BenchmarkMapMetric,
BenchmarkMetric,
BenchmarkMetricBase,
BenchmarkTimeVaryingMetric,
)
from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata
from ax.core.arm import Arm
@@ -114,6 +115,10 @@ def setUp(self) -> None:
name: BenchmarkMetric(name=name, lower_is_better=True)
for name in self.outcome_names
}
self.tv_metrics = {
name: BenchmarkTimeVaryingMetric(name=name, lower_is_better=True)
for name in self.outcome_names
}
self.map_metrics = {
name: BenchmarkMapMetric(name=name, lower_is_better=True)
for name in self.outcome_names
@@ -126,9 +131,13 @@ def test_available_while_running(self) -> None:
self.assertTrue(self.map_metrics["test_metric1"].is_available_while_running())
self.assertTrue(BenchmarkMapMetric.is_available_while_running())

self.assertTrue(self.tv_metrics["test_metric1"].is_available_while_running())
self.assertTrue(BenchmarkTimeVaryingMetric.is_available_while_running())

def test_exceptions(self) -> None:
for metric in [
self.metrics["test_metric1"],
self.tv_metrics["test_metric1"],
self.map_metrics["test_metric1"],
]:
with self.subTest(
@@ -169,7 +178,7 @@ def _test_fetch_trial_data_one_time_step(
def test_fetch_trial_data_one_time_step(self) -> None:
for batch, metrics in product(
[False, True],
[self.metrics, self.map_metrics],
[self.metrics, self.tv_metrics, self.map_metrics],
):
metric = metrics["test_metric1"]
with self.subTest(
@@ -182,15 +191,17 @@ def _test_fetch_trial_multiple_time_steps_with_simulator(self, batch: bool) -> None:
"""
Cases for fetching data with multiple time steps:
- Metric is 'BenchmarkMetric' -> exception, tested below
- Has simulator, metric is 'BenchmarkTimeVaryingMetric' -> one step at
a time, evolving as we take steps
- Has simulator, metric is 'BenchmarkMapMetric' -> df grows with each step
- No simulator, metric is 'BenchmarkTimeVaryingMetric' -> exception,
tested below
- No simulator, metric is 'BenchmarkMapMetric' -> all data present while
running (this behavior may be undesirable)
"""
metric_name = "test_metric1"

# Iterating over list of length 1 here because there will be more
# metrics in the next diff
for metric in [self.map_metrics[metric_name]]:
for metric in [self.map_metrics[metric_name], self.tv_metrics[metric_name]]:
trial_with_simulator = get_test_trial(
has_metadata=True,
batch=batch,
@@ -245,6 +256,19 @@ def test_fetch_trial_multiple_time_steps_with_simulator(self) -> None:
self._test_fetch_trial_multiple_time_steps_with_simulator(batch=False)
self._test_fetch_trial_multiple_time_steps_with_simulator(batch=True)

def test_fetch_time_varying_metric_without_simulator(self) -> None:
metric_name = "test_metric1"
metric = self.tv_metrics[metric_name]

trial = get_test_trial(
has_metadata=True,
batch=False,
multiple_time_steps=True,
has_simulator=False,
)
with self.assertRaisesRegex(ValueError, "has no `backend_simulator`"):
metric.fetch_trial_data(trial=trial)

def test_sim_trial_completes_in_future_raises(self) -> None:
simulator = BackendSimulator()
simulator.run_trial(trial_index=0, runtime=0)
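A minimal sketch of the intended runtime behavior, assuming the `BackendSimulator` used in the tests above advances an internal clock (the import path and the `update()` call are assumptions, not shown in this diff):

from ax.utils.testing.backend_simulator import BackendSimulator  # path assumed

simulator = BackendSimulator()
simulator.run_trial(trial_index=0, runtime=10)

# While the simulated trial is running, `BenchmarkTimeVaryingMetric` returns a
# single value corresponding to the simulator's current `time`, whereas
# `BenchmarkMapMetric` returns `MapData` whose df grows with each step.
# simulator.update()  # assumed to advance simulated time; fetching again would
#                     # then yield the value for the new time step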
