Refactor benchmark metric classes to increase code sharing (facebook#3182)

Summary:

**Context**: The two benchmark metric classes, `BenchmarkMetric` and `BenchmarkMapMetric`, share some code that is currently duplicated. I initially didn't want to have them inherit from a common `BenchmarkMetric` class, since this would cause a potential diamond inheritance issue in `BenchmarkMapMetric`. However, I don't see this issue as very risky, and introducing a base class will make it easy to add classes we don't currently have: a non-Map metric that is available while running, and a Map metric that is not available while running.

Note that there are exactly four possible benchmark metric classes (map data vs. not, available while running vs. not), and these cannot be consolidated into fewer classes, since metrics must inherit from `MapMetric` if and only if they produce `MapData`, and `available_while_running` is a class method rather than an instance attribute.
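
As a rough illustration (a standalone sketch with hypothetical class names, not code from this commit) of why neither dimension can be an instance-level flag:

```python
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric


class ScalarMetricExample(Metric):
    """Produces `Data`; whether a metric subclasses `MapMetric` is fixed per class."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        # A classmethod: availability is a property of the class, not of an
        # instance, so each (map vs. not) x (available vs. not) combination
        # needs its own class -- hence exactly four benchmark metric classes.
        return True


class MapMetricExample(MapMetric):
    """Must subclass `MapMetric` because it produces `MapData`."""
```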

**This diff**:
* Introduces a base class `BenchmarkMetricBase`, consolidates logic into it, and makes `BenchmarkMetric` and `BenchmarkMapMetric` inherit from it

`BenchmarkMetricBase.fetch_trial_data` may appear unnecessarily complex for the two classes we have now, but will not require any changes to add two more classes.
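
For example, under this structure the two missing classes could plausibly be added later by overriding only `is_available_while_running` and `_df_to_result`, with no change to `BenchmarkMetricBase.fetch_trial_data`. A hedged sketch (these classes are not part of this commit, and the `MapData`/`MapKeyInfo` import path is assumed):

```python
from ax.benchmark.benchmark_metric import BenchmarkMetricBase
from ax.core.data import Data
from ax.core.map_data import MapData, MapKeyInfo  # assumed import path
from ax.core.map_metric import MapMetric
from ax.core.metric import MetricFetchResult
from ax.utils.common.result import Ok
from pandas import DataFrame


class BenchmarkTimeVaryingMetric(BenchmarkMetricBase):
    """Non-map data that can be read while the trial is running."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return True

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        # Same conversion as `BenchmarkMetric`: drop the map key, return `Data`.
        return Ok(value=Data(df=df.drop(columns=["step"])))


class BenchmarkMapUnavailableWhileRunningMetric(MapMetric, BenchmarkMetricBase):
    """Map data that only becomes readable once the trial completes."""

    map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
        return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
```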

Reviewed By: Balandat

Differential Revision: D67254165
esantorella authored and facebook-github-bot committed Dec 18, 2024
1 parent ceb07f5 commit 359b656
Showing 2 changed files with 179 additions and 122 deletions.
299 changes: 178 additions & 121 deletions ax/benchmark/benchmark_metric.py
@@ -5,8 +5,94 @@

# pyre-strict

"""
Metric classes for Ax benchmarking.
Metrics vary on two dimensions: Whether they are `MapMetric`s or not, and
whether they are available while running or not.
There are two Metric classes:
- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
is not available while running.
- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
data is available while running.
There are further benchmark classes that are not yet implemented:
- `BenchmarkTimeVaryingMetric`: For when outputs should be `Data` and the metric
is available while running.
- `BenchmarkMapUnavailableWhileRunningMetric`: For when outputs should be
`MapData` and the metric is not available while running.
Any of these can be used with or without a simulator. However,
`BenchmarkMetric.fetch_trial_data` cannot take in data with multiple time steps,
as they will not be used and this is assumed to be an error. The below table
enumerates use cases.
.. list-table:: Benchmark Metrics Table
:widths: 5 25 5 5 5 50
:header-rows: 1
* -
- Metric
- Map
- Available while running
- Simulator
- Reason/use case
* - 1
- BenchmarkMetric
- No
- No
- No
- Vanilla
* - 2
- BenchmarkMetric
- No
- No
- Yes
- Asynchronous, data read only at end
* - 3
- BenchmarkTimeVaryingMetric
- No
- Yes
- No
- Behaves like #1 because it will never be RUNNING
* - 4
- BenchmarkTimeVaryingMetric
- No
- Yes
- Yes
- Scalar data that changes over time
* - 5
- BenchmarkMapUnavailableWhileRunningMetric
- Yes
- No
- No
- MapData that returns immediately; could be used for getting baseline
* - 6
- BenchmarkMapUnavailableWhileRunningMetric
- Yes
- No
- Yes
- Asynchronicity with MapData read only at end
* - 7
- BenchmarkMapMetric
- Yes
- Yes
- No
- Behaves same as #5
* - 8
- BenchmarkMapMetric
- Yes
- Yes
- Yes
- Early stopping
"""

from abc import abstractmethod
from typing import Any

from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata

from ax.core.base_trial import BaseTrial
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
@@ -15,47 +15,101 @@
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok
from pandas import DataFrame
from pyre_extensions import none_throws


def _get_no_metadata_msg(trial_index: int) -> str:
return f"No metadata available for trial {trial_index}."


def _get_no_metadata_err(trial: BaseTrial) -> Err[Data, MetricFetchE]:
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)


def _validate_trial_and_kwargs(
trial: BaseTrial, class_name: str, **kwargs: Any
) -> None:
"""
Validate that:
- Kwargs are empty
- No arms within a BatchTrial have been abandoned
"""
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
"BenchmarkMetric does not support abandoned arms in batch trials."
)


class BenchmarkMetric(Metric):
"""A generic metric used for observed values produced by Ax Benchmarks.
Compatible with results generated by `BenchmarkRunner`.
"""

class BenchmarkMetricBase(Metric):
def __init__(
self,
name: str,
@@ -77,122 +131,125 @@ def __init__(
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)

df = trial.run_metadata["benchmark_metadata"].dfs[self.name]
if df["step"].nunique() > 1:
raise ValueError(
f"Trial {trial.index} has data from multiple time steps. This is"
" not supported by `BenchmarkMetric`; use `BenchmarkMapMetric`."
)
df = df.drop(columns=["step", "virtual runtime"])
if not self.observe_noise_sd:
df["sem"] = None
return Ok(value=Data(df=df))


class BenchmarkMapMetric(MapMetric):
# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

def __init__(
self,
name: str,
# Needed to be boolean (not None) for validation of MOO opt configs
lower_is_better: bool,
observe_noise_sd: bool = True,
def _class_specific_metdata_validation(
self, metadata: BenchmarkTrialMetadata | None
) -> None:
"""
Args:
name: Name of the metric.
lower_is_better: If `True`, lower metric values are considered better.
observe_noise_sd: If `True`, the standard deviation of the observation
                noise is included in the `sem` column of the returned data.
If `False`, `sem` is set to `None` (meaning that the model will
have to infer the noise level).
"""
super().__init__(name=name, lower_is_better=lower_is_better)
# Declare `lower_is_better` as bool (rather than optional as in the base class)
self.lower_is_better: bool = lower_is_better
self.observe_noise_sd: bool = observe_noise_sd

@classmethod
def is_available_while_running(cls) -> bool:
return True
return

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
If the trial has been completed, look up the ``sim_start_time`` and
``sim_completed_time`` on the corresponding ``SimTrial``, and return all
data from keys 0, ..., ``sim_completed_time - sim_start_time``. If the
trial has not completed, return all data from keys 0, ..., ``sim_runtime
- sim_start_time``.
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
_validate_trial_and_kwargs(
trial=trial, class_name=self.__class__.__name__, **kwargs
)

class_name = self.__class__.__name__
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{class_name}.fetch_trial_data."
)
if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
raise NotImplementedError(
f"{self.__class__.__name__} does not support abandoned arms in "
"batch trials."
)
if len(trial.run_metadata) == 0:
return _get_no_metadata_err(trial=trial)
return Err(
MetricFetchE(
message=_get_no_metadata_msg(trial_index=trial.index),
exception=None,
)
)

metadata = trial.run_metadata["benchmark_metadata"]

self._class_specific_metdata_validation(metadata=metadata)
backend_simulator = metadata.backend_simulator
df = metadata.dfs[self.name]

        # Filter down to the data that is observable so far
if backend_simulator is None:
max_t = float("inf")
# If there's no backend simulator then no filtering is needed; the
# trial will complete immediately, with all data available.
available_data = df
else:
sim_trial = none_throws(
backend_simulator.get_sim_trial_by_index(trial.index)
)
# The BackendSimulator distinguishes between queued and running
# trials "for testing particular initialization cases", but these
# are all "running" to Scheduler.
# start_time = none_throws(sim_trial.sim_queued_time)
start_time = none_throws(sim_trial.sim_start_time)

if sim_trial.sim_completed_time is None: # Still running
max_t = backend_simulator.time - start_time
elif sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
else:
if sim_trial.sim_completed_time > backend_simulator.time:
raise RuntimeError(
"The trial's completion time is in the future! This is "
f"unexpected. {sim_trial.sim_completed_time=}, "
f"{backend_simulator.time=}"
)
# Completed, may have stopped early
max_t = none_throws(sim_trial.sim_completed_time) - start_time

df = (
metadata.dfs[self.name]
.loc[lambda x: x["virtual runtime"] <= max_t]
.drop(columns=["virtual runtime"])
.reset_index(drop=True)
# Just in case the key was renamed by a subclass
.rename(columns={"step": self.map_key_info.key})
)
# Completed, may have stopped early -- can't assume all data available
completed_time = none_throws(sim_trial.sim_completed_time)
max_t = completed_time - start_time

available_data = df[df["virtual runtime"] <= max_t]

if not self.observe_noise_sd:
df["sem"] = None
available_data["sem"] = None
return self._df_to_result(df=available_data.drop(columns=["virtual runtime"]))

@abstractmethod
def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
"""
Convert a DataFrame of observable data to Data or MapData, as
appropriate for the class.
"""
...


class BenchmarkMetric(BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is not available while
running.
"""

def _class_specific_metdata_validation(
self, metadata: BenchmarkTrialMetadata | None
) -> None:
if metadata is not None:
df = metadata.dfs[self.name]
if df["step"].nunique() > 1:
raise ValueError(
f"Trial has data from multiple time steps. This is"
f" not supported by `{self.__class__.__name__}`; use "
"`BenchmarkMapMetric`."
)

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
return Ok(value=Data(df=df.drop(columns=["step"])))


class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
"""
Metric for benchmarking that produces `Data` and is available while
running.
"""

# pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
# defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
# subtype of the overridden attribute `MapKeyInfo[float]`
map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)

@classmethod
def is_available_while_running(cls) -> bool:
return True

def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
# Just in case the key was renamed by a subclass
df = df.rename(columns={"step": self.map_key_info.key})
return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info]))
2 changes: 1 addition & 1 deletion ax/benchmark/tests/test_benchmark_metric.py
@@ -116,7 +116,7 @@ def test_fetch_trial_data(self) -> None:
"benchmark_metadata"
]
with self.assertRaisesRegex(
ValueError, "Trial 0 has data from multiple time steps"
ValueError, "Trial has data from multiple time steps"
):
self.metric1.fetch_trial_data(trial=trial)
