Skip to content

Commit

Permalink
Refactor benchmark metric classes to increase code sharing (facebook#…
Browse files Browse the repository at this point in the history
…3182)

Summary:

**Context**: The two benchmark metric classes, `BenchmarkMetric` and `BenchmarkMapMetric`, share some code that is currently duplicated. I initially didn't want to have them inherit from a common `BenchmarkMetric` class, since this would cause a potential diamond inheritance issue in `BenchmarkMapMetric`. However, I don't see this issue as very risky, and introducing a base class will make it easy to add  classes we don't currently have: a non-Map metric that is available while running, and a Map metric that is not available while running.

Note that there are exactly four possible benchmark metric classes (map data vs. not, available while running vs. not) and these cannot be consolidated into fewer classes, since metrics must inherit from MapMetric if and only if they produce MapData, and `available_while_running` is a class method.

**This diff**:
* Introduces a base class `BenchmarkMetricBase`, consolidates logic into it, and makes `BenchmarkMetric` and `BenchmarkMapMetric` inherit from it

`BenchmarkMetricBase.fetch_trial_data` may appear unnecessarily complex for the two classes we have now, but will not require any changes to add two more classes.

Differential Revision: D67254165
  • Loading branch information
esantorella authored and facebook-github-bot committed Dec 16, 2024
1 parent f71703f commit 63ef940
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 124 deletions.
228 changes: 105 additions & 123 deletions ax/benchmark/benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,20 @@

# pyre-strict

"""
Metric classes for Ax benchmarking.
There are two Metric classes:
- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
is not available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
data is available while running.
"""

from typing import Any

from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata

from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
Expand All @@ -22,40 +34,7 @@ def _get_no_metadata_msg(trial_index: int) -> str:
return f"No metadata available for trial {trial_index}."


def _get_no_metadata_err(trial: BaseTrial) -> Err[Data, MetricFetchE]:
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)


def _validate_trial_and_kwargs(
trial: BaseTrial, class_name: str, **kwargs: Any
) -> None:
"""
Validate that:
- Kwargs are empty
- No arms within a BatchTrial have been abandoned
"""
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
"BenchmarkMetric does not support abandoned arms in batch trials."
)


class BenchmarkMetric(Metric):
"""A generic metric used for observed values produced by Ax Benchmarks.
Compatible with results generated by `BenchmarkRunner`.
"""

class BenchmarkMetricBase(Metric):
def __init__(
self,
name: str,
Expand All @@ -77,122 +56,125 @@ def __init__(
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)

df = trial.run_metadata["benchmark_metadata"].dfs[self.name]
if df["step"].nunique() > 1:
raise ValueError(
f"Trial {trial.index} has data from multiple time steps. This is"
" not supported by `BenchmarkMetric`; use `BenchmarkMapMetric`."
)
df = df.drop(columns=["step", "virtual runtime"])
if not self.observe_noise_sd:
df["sem"] = None
return Ok(value=Data(df=df))


class BenchmarkMapMetric(MapMetric):
# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

def __init__(
self,
name: str,
# Needed to be boolean (not None) for validation of MOO opt configs
lower_is_better: bool,
observe_noise_sd: bool = True,
def _class_specific_metadata_validation(
self, metadata: BenchmarkTrialMetadata, trial_index: int
) -> None:
"""
Args:
name: Name of the metric.
lower_is_better: If `True`, lower metric values are considered better.
observe_noise_sd: If `True`, the standard deviation of the observation
noise is included in the `sem` column of the the returned data.
If `False`, `sem` is set to `None` (meaning that the model will
have to infer the noise level).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
# Declare `lower_is_better` as bool (rather than optional as in the base class)
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd

@classmethod
def is_available_while_running(cls) -> bool:
return True
"""Check whether this metadata is valid for use with `fetch_trial_data`."""
return

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
If the trial has been completed, look up the ``sim_start_time`` and
``sim_completed_time`` on the corresponding ``SimTrial``, and return all
data from keys 0, ..., ``sim_completed_time - sim_start_time``. If the
trial has not completed, return all data from keys 0, ..., ``sim_runtime
- sim_start_time``.
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)

class_name = self.__class__.__name__
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
f"{self.__class__.__name__} does not support abandoned arms in "
"batch trials."
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)

metadata = trial.run_metadata["benchmark_metadata"]
self._class_specific_metadata_validation(
metadata=metadata, trial_index=trial.index
)

backend_simulator = metadata.backend_simulator
df = metadata.dfs[self.name]

# Filter out the observable data
if backend_simulator is None:
max_t = float("inf")
# If there's no backend simulator (so the metric is not available while
# running) then no filtering is needed
available_data = df
else:
sim_trial = none_throws(
backend_simulator.get_sim_trial_by_index(trial.index)
)
# The BackendSimulator distinguishes between queued and running
# trials "for testing particular initialization cases", but these
# are all "running" to Scheduler.
# start_time = none_throws(sim_trial.sim_queued_time)
start_time = none_throws(sim_trial.sim_start_time)

if sim_trial.sim_completed_time is None: # Still running
max_t = backend_simulator.time - start_time
elif sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
else:
if sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
# Completed, may have stopped early
max_t = none_throws(sim_trial.sim_completed_time) - start_time

df = (
metadata.dfs[self.name]
.loc[lambda x: x["virtual runtime"] <= max_t]
.drop(columns=["virtual runtime"])
.reset_index(drop=True)
# Just in case the key was renamed by a subclass
.rename(columns={"step": self.map_key_info.key})
)
# Completed, may have stopped early -- can't assume all data available
completed_time = none_throws(sim_trial.sim_completed_time)
max_t = completed_time - start_time

available_data = df[df["virtual runtime"] <= max_t]

# MapData case: return all available data
if isinstance(self, MapMetric):
df = (
available_data
# Just in case the key was renamed by a subclass
.rename(columns={"step": self.map_key_info.key})
)
else:
# Non-MapData case: return the most recent data
df = available_data.loc[
available_data["virtual runtime"]
== available_data["virtual runtime"].max()
].drop(columns=["step"])
df = df.drop(columns=["virtual runtime"]).reset_index(drop=True)
if not self.observe_noise_sd:
df["sem"] = None
if isinstance(self, MapMetric):
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
return Ok(value=Data(df=df))


class BenchmarkMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is not available while
running.
"""

def _class_specific_metadata_validation(
self, metadata: BenchmarkTrialMetadata, trial_index: int
) -> None:
df = metadata.dfs[self.name]
if df["step"].nunique() > 1:
raise ValueError(
f"Trial {trial_index} has data from multiple time steps. This is"
f" not supported by `{self.__class__.__name__}`; use "
"`BenchmarkMapMetric`."
)


return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is available while
running.
"""

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

@classmethod
def is_available_while_running(cls) -> bool:
return True
9 changes: 8 additions & 1 deletion ax/benchmark/tests/test_benchmark_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def setUp(self) -> None:
for name in self.outcome_names
)

def test_available_while_running(self) -> None:
self.assertFalse(self.metric1.is_available_while_running())
self.assertFalse(BenchmarkMetric.is_available_while_running())
self.assertTrue(self.map_metric1.is_available_while_running())
self.assertTrue(BenchmarkMapMetric.is_available_while_running())

def test_fetch_trial_data(self) -> None:
with self.subTest("Error for multiple metrics in BenchmarkMetric"):
trial = get_test_trial()
Expand Down Expand Up @@ -251,7 +257,8 @@ def test_sim_trial_completes_in_future_raises(self) -> None:
simulator.update()
simulator.options.internal_clock = -1
metadata = BenchmarkTrialMetadata(
dfs={"test_metric": pd.DataFrame({"t": [3]})}, backend_simulator=simulator
dfs={"test_metric": pd.DataFrame({"t": [3], "step": 0})},
backend_simulator=simulator,
)
trial = Mock(spec=Trial)
trial.index = 0
Expand Down

0 comments on commit 63ef940

Please sign in to comment.