Refactor benchmark metric classes to increase code sharing (facebook#3182)

Summary:

**Context**: The two benchmark metric classes, `BenchmarkMetric` and `BenchmarkMapMetric`, currently duplicate some code. I initially didn't want to have them inherit from a common base class, since this would create a potential diamond-inheritance issue in `BenchmarkMapMetric`. However, I don't see that issue as very risky, and introducing a base class will make it easy to add classes we don't currently have: a non-Map metric that is available while running, and a Map metric that is not available while running.

Note that there are exactly four possible benchmark metric classes (map data vs. not, available while running vs. not), and these cannot be consolidated into fewer classes, since metrics must inherit from `MapMetric` if and only if they produce `MapData`, and `available_while_running` is a class method.
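
To make the grid concrete, here is a minimal sketch of the four combinations. It is illustrative only and not part of this diff: the class names are placeholders (the first two correspond to `BenchmarkMetric` and `BenchmarkMapMetric`), and it assumes only what the diff already relies on, namely that `is_available_while_running` is a classmethod on `Metric` and that `MapData`-producing metrics subclass `MapMetric`.

```python
# Sketch of the 2x2 space; none of these class names exist in Ax.
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric


class NonMapNotWhileRunning(Metric):
    """Plain `Data`, not available while running (cf. `BenchmarkMetric`)."""


class MapWhileRunning(MapMetric):
    """`MapData`, available while running (cf. `BenchmarkMapMetric`)."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True


class NonMapWhileRunning(Metric):
    """Hypothetical: plain `Data`, but available while running."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True


class MapNotWhileRunning(MapMetric):
    """Hypothetical: `MapData`, but only available after completion."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return False
```

Because availability is fixed at the class level and map-ness is fixed by inheritance from `MapMetric`, no two cells of the grid can be served by a single class.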

**This diff**:
* Introduces a base class `BenchmarkMetricBase`, consolidates logic into it, and makes `BenchmarkMetric` and `BenchmarkMapMetric` inherit from it

`BenchmarkMetricBase.fetch_trial_data` may appear unnecessarily complex for the two classes we have now, but will not require any changes to add two more classes.
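
As a rough illustration of that point, the following sketch shows what one of the two missing classes could look like under the new design. It is an assumption-laden example, not part of this diff: the name `BenchmarkTimeVaryingMetric`, the presence of an `arm_name` column in the benchmark dataframes, and the "keep the latest step per arm" reduction are all hypothetical. The point is that only `_df_to_result` (plus the availability classmethod) needs to be written; `BenchmarkMetricBase.fetch_trial_data` is reused unchanged.

```python
# Hypothetical subclass; the name and column handling are illustrative only.
from ax.benchmark.benchmark_metric import BenchmarkMetricBase
from ax.core.data import Data
from ax.core.metric import MetricFetchResult
from ax.utils.common.result import Ok
from pandas import DataFrame


class BenchmarkTimeVaryingMetric(BenchmarkMetricBase):
    """Produces plain `Data` but is available while the trial is running."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        # `fetch_trial_data` has already restricted `df` to the rows observable
        # at the current (virtual) time and dropped the "virtual runtime"
        # column, so reduce to the latest observed step per arm and return
        # non-map `Data`.
        latest = df.sort_values("step").groupby("arm_name").tail(1)
        return Ok(value=Data(df=latest.drop(columns=["step"])))
```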

Differential Revision: D67254165
esantorella authored and facebook-github-bot committed Dec 17, 2024
1 parent c9303da commit 2d0c33b
Showing 1 changed file with 91 additions and 121 deletions.
212 changes: 91 additions & 121 deletions ax/benchmark/benchmark_metric.py
@@ -5,6 +5,17 @@

# pyre-strict

"""
Metric classes for Ax benchmarking.
There are two Metric classes:
- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
is not available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
data is available while running.
"""

from abc import abstractmethod
from typing import Any

from ax.core.base_trial import BaseTrial
@@ -15,47 +26,15 @@
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from pandas import DataFrame
from pyre_extensions import none_throws


def _get_no_metadata_msg(trial_index: int) -> str:
return f"No metadata available for trial {trial_index}."


def _get_no_metadata_err(trial: BaseTrial) -> Err[Data, MetricFetchE]:
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)


def _validate_trial_and_kwargs(
trial: BaseTrial, class_name: str, **kwargs: Any
) -> None:
"""
Validate that:
- Kwargs are empty
- No arms within a BatchTrial have been abandoned
"""
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
"BenchmarkMetric does not support abandoned arms in batch trials."
)


class BenchmarkMetric(Metric):
"""A generic metric used for observed values produced by Ax Benchmarks.
Compatible with results generated by `BenchmarkRunner`.
"""

class BenchmarkMetricBase(Metric):
def __init__(
self,
name: str,
@@ -86,113 +65,104 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)

df = trial.run_metadata["benchmark_metadata"].dfs[self.name]
if df["step"].nunique() > 1:
raise ValueError(
f"Trial {trial.index} has data from multiple time steps. This is"
" not supported by `BenchmarkMetric`; use `BenchmarkMapMetric`."
class_name = self.__class__.__name__
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
f"{self.__class__.__name__} does not support abandoned arms in "
"batch trials."
)
df = df.drop(columns=["step", "virtual runtime"])
if not self.observe_noise_sd:
df["sem"] = None
return Ok(value=Data(df=df))


class BenchmarkMapMetric(MapMetric):
# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

def __init__(
self,
name: str,
# Needed to be boolean (not None) for validation of MOO opt configs
lower_is_better: bool,
observe_noise_sd: bool = True,
) -> None:
"""
Args:
name: Name of the metric.
lower_is_better: If `True`, lower metric values are considered better.
observe_noise_sd: If `True`, the standard deviation of the observation
noise is included in the `sem` column of the returned data.
If `False`, `sem` is set to `None` (meaning that the model will
have to infer the noise level).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
# Declare `lower_is_better` as bool (rather than optional as in the base class)
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd

@classmethod
def is_available_while_running(cls) -> bool:
return True

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
If the trial has been completed, look up the ``sim_start_time`` and
``sim_completed_time`` on the corresponding ``SimTrial``, and return all
data from keys 0, ..., ``sim_completed_time - sim_start_time``. If the
trial has not completed, return all data from keys 0, ..., ``sim_runtime
- sim_start_time``.
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)

metadata = trial.run_metadata["benchmark_metadata"]

backend_simulator = metadata.backend_simulator
df = metadata.dfs[self.name]

# Filter out the observable data
if backend_simulator is None:
max_t = float("inf")
# If there's no backend simulator then no filtering is needed; the
# trial will complete immediately, with all data available.
available_data = df
else:
sim_trial = none_throws(
backend_simulator.get_sim_trial_by_index(trial.index)
)
# The BackendSimulator distinguishes between queued and running
# trials "for testing particular initialization cases", but these
# are all "running" to Scheduler.
# start_time = none_throws(sim_trial.sim_queued_time)
start_time = none_throws(sim_trial.sim_start_time)

if sim_trial.sim_completed_time is None: # Still running
max_t = backend_simulator.time - start_time
elif sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
else:
if sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
# Completed, may have stopped early
max_t = none_throws(sim_trial.sim_completed_time) - start_time

df = (
metadata.dfs[self.name]
.loc[lambda x: x["virtual runtime"] <= max_t]
.drop(columns=["virtual runtime"])
.reset_index(drop=True)
# Just in case the key was renamed by a subclass
.rename(columns={"step": self.map_key_info.key})
)
# Completed, may have stopped early -- can't assume all data available
completed_time = none_throws(sim_trial.sim_completed_time)
max_t = completed_time - start_time

available_data = df[df["virtual runtime"] <= max_t]

if not self.observe_noise_sd:
df["sem"] = None
available_data["sem"] = None
return self._df_to_result(df=available_data.drop(columns=["virtual runtime"]))

@abstractmethod
def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
"""
Convert a DataFrame of observable data to Data or MapData, as
appropriate for the class.
"""
...


class BenchmarkMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is not available while
running.
"""

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
if df["step"].nunique() > 1:
raise ValueError(
f"Trial has data from multiple time steps. This is"
f" not supported by `{self.__class__.__name__}`; use "
"`BenchmarkMapMetric`."
)
return Ok(value=Data(df=df.drop(columns=["step"])))


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
"""
Metric for benchmarking that produces `MapData` and is available while
running.
"""

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

@classmethod
def is_available_while_running(cls) -> bool:
return True

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
# Just in case the key was renamed by a subclass
df = df.rename(columns={"step": self.map_key_info.key})
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
