diff --git a/ax/benchmark/benchmark_metric.py b/ax/benchmark/benchmark_metric.py
index a92cf2f49bd..33370979791 100644
--- a/ax/benchmark/benchmark_metric.py
+++ b/ax/benchmark/benchmark_metric.py
@@ -5,8 +5,94 @@
 # pyre-strict
 
+"""
+Metric classes for Ax benchmarking.
+
+Metrics vary along two dimensions: whether they are `MapMetric`s, and whether
+they are available while running.
+
+There are two Metric classes:
+- `BenchmarkMetric`: For when outputs should be `Data` (not `MapData`) and data
+  is not available while running.
+- `BenchmarkMapMetric`: For when outputs should be `MapData` (not `Data`) and
+  data is available while running.
+
+There are two further benchmark metric classes that are not yet implemented:
+- `BenchmarkTimeVaryingMetric`: For when outputs should be `Data` and the metric
+  is available while running.
+- `BenchmarkMapUnavailableWhileRunningMetric`: For when outputs should be
+  `MapData` and the metric is not available while running.
+
+Any of these can be used with or without a simulator. However,
+`BenchmarkMetric.fetch_trial_data` cannot accept data with multiple time steps:
+such data would not be used, so passing it is assumed to be an error. The table
+below enumerates the use cases.
+
+.. list-table:: Benchmark Metrics Table
+    :widths: 5 25 5 5 5 50
+    :header-rows: 1
+
+    * -
+      - Metric
+      - Map
+      - Available while running
+      - Simulator
+      - Reason/use case
+    * - 1
+      - BenchmarkMetric
+      - No
+      - No
+      - No
+      - Vanilla
+    * - 2
+      - BenchmarkMetric
+      - No
+      - No
+      - Yes
+      - Asynchronous, data read only at end
+    * - 3
+      - BenchmarkTimeVaryingMetric
+      - No
+      - Yes
+      - No
+      - Behaves like #1 because it will never be RUNNING
+    * - 4
+      - BenchmarkTimeVaryingMetric
+      - No
+      - Yes
+      - Yes
+      - Scalar data that changes over time
+    * - 5
+      - BenchmarkMapUnavailableWhileRunningMetric
+      - Yes
+      - No
+      - No
+      - MapData that returns immediately; could be used for getting baseline
+    * - 6
+      - BenchmarkMapUnavailableWhileRunningMetric
+      - Yes
+      - No
+      - Yes
+      - Asynchronicity with MapData read only at end
+    * - 7
+      - BenchmarkMapMetric
+      - Yes
+      - Yes
+      - No
+      - Behaves same as #5
+    * - 8
+      - BenchmarkMapMetric
+      - Yes
+      - Yes
+      - Yes
+      - Early stopping
+"""
+
+from abc import abstractmethod
 from typing import Any
+from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata
+
 from ax.core.base_trial import BaseTrial
 from ax.core.batch_trial import BatchTrial
 from ax.core.data import Data
@@ -15,6 +101,7 @@ from ax.core.map_metric import MapMetric
 from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
 from ax.utils.common.result import Err, Ok
+from pandas import DataFrame
 from pyre_extensions import none_throws
 
 
@@ -22,40 +109,7 @@ def _get_no_metadata_msg(trial_index: int) -> str:
     return f"No metadata available for trial {trial_index}."
 
 
-def _get_no_metadata_err(trial: BaseTrial) -> Err[Data, MetricFetchE]:
-    return Err(
-        MetricFetchE(
-            message=_get_no_metadata_msg(trial_index=trial.index),
-            exception=None,
-        )
-    )
-
-
-def _validate_trial_and_kwargs(
-    trial: BaseTrial, class_name: str, **kwargs: Any
-) -> None:
-    """
-    Validate that:
-    - Kwargs are empty
-    - No arms within a BatchTrial have been abandoned
-    """
-    if len(kwargs) > 0:
-        raise NotImplementedError(
-            f"Arguments {set(kwargs)} are not supported in "
-            f"{class_name}.fetch_trial_data."
-        )
-    if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
-        raise NotImplementedError(
-            "BenchmarkMetric does not support abandoned arms in batch trials."
-        )
-
-
-class BenchmarkMetric(Metric):
-    """A generic metric used for observed values produced by Ax Benchmarks.
-
-    Compatible with results generated by `BenchmarkRunner`.
-    """
-
+class BenchmarkMetricBase(Metric):
+    """Base class for benchmark metrics.
+
+    Handles argument and trial validation as well as data fetching. Subclasses
+    may override `_class_specific_metadata_validation` and must implement
+    `_df_to_result` to return either `Data` or `MapData`.
+    """
+
     def __init__(
         self,
         name: str,
@@ -77,72 +131,13 @@ def __init__(
         self.lower_is_better: bool = lower_is_better
         self.observe_noise_sd: bool = observe_noise_sd
 
-    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
-        """
-        Args:
-            trial: The trial from which to fetch data.
-            kwargs: Unsupported and will raise an exception.
-
-        Returns:
-            A MetricFetchResult containing the data for the requested metric.
-        """
-        _validate_trial_and_kwargs(
-            trial=trial, class_name=self.__class__.__name__, **kwargs
-        )
-        if len(trial.run_metadata) == 0:
-            return _get_no_metadata_err(trial=trial)
-
-        df = trial.run_metadata["benchmark_metadata"].dfs[self.name]
-        if df["step"].nunique() > 1:
-            raise ValueError(
-                f"Trial {trial.index} has data from multiple time steps. This is"
-                " not supported by `BenchmarkMetric`; use `BenchmarkMapMetric`."
-            )
-        df = df.drop(columns=["step", "virtual runtime"])
-        if not self.observe_noise_sd:
-            df["sem"] = None
-        return Ok(value=Data(df=df))
-
-
-class BenchmarkMapMetric(MapMetric):
-    # pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute
-    # defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a
-    # subtype of the overridden attribute `MapKeyInfo[float]`
-    map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0)
-
-    def __init__(
-        self,
-        name: str,
-        # Needed to be boolean (not None) for validation of MOO opt configs
-        lower_is_better: bool,
-        observe_noise_sd: bool = True,
+    def _class_specific_metadata_validation(
+        self, metadata: BenchmarkTrialMetadata | None
     ) -> None:
-        """
-        Args:
-            name: Name of the metric.
-            lower_is_better: If `True`, lower metric values are considered better.
-            observe_noise_sd: If `True`, the standard deviation of the observation
-                noise is included in the `sem` column of the the returned data.
-                If `False`, `sem` is set to `None` (meaning that the model will
-                have to infer the noise level).
-        """
-        super().__init__(name=name, lower_is_better=lower_is_better)
-        # Declare `lower_is_better` as bool (rather than optional as in the base class)
-        self.lower_is_better: bool = lower_is_better
-        self.observe_noise_sd: bool = observe_noise_sd
-
-    @classmethod
-    def is_available_while_running(cls) -> bool:
-        return True
+        return
 
     def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
         """
-        If the trial has been completed, look up the ``sim_start_time`` and
-        ``sim_completed_time`` on the corresponding ``SimTrial``, and return all
-        data from keys 0, ..., ``sim_completed_time - sim_start_time``. If the
-        trial has not completed, return all data from keys 0, ..., ``sim_runtime -
-        sim_start_time``.
-
         Args:
             trial: The trial from which to fetch data.
             kwargs: Unsupported and will raise an exception.
@@ -150,18 +145,36 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult
         Returns:
             A MetricFetchResult containing the data for the requested metric.
         """
-        _validate_trial_and_kwargs(
-            trial=trial, class_name=self.__class__.__name__, **kwargs
-        )
+
+        class_name = self.__class__.__name__
+        if len(kwargs) > 0:
+            raise NotImplementedError(
+                f"Arguments {set(kwargs)} are not supported in "
+                f"{class_name}.fetch_trial_data."
+            )
+        if isinstance(trial, BatchTrial) and len(trial.abandoned_arms) > 0:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} does not support abandoned arms in "
+                "batch trials."
+            )
         if len(trial.run_metadata) == 0:
-            return _get_no_metadata_err(trial=trial)
+            return Err(
+                MetricFetchE(
+                    message=_get_no_metadata_msg(trial_index=trial.index),
+                    exception=None,
+                )
+            )
 
         metadata = trial.run_metadata["benchmark_metadata"]
-
+        self._class_specific_metadata_validation(metadata=metadata)
         backend_simulator = metadata.backend_simulator
+        df = metadata.dfs[self.name]
+
+        # Restrict to the data observable at the current (virtual) time
         if backend_simulator is None:
-            max_t = float("inf")
+            # If there's no backend simulator then no filtering is needed; the
+            # trial will complete immediately, with all data available.
+            available_data = df
         else:
             sim_trial = none_throws(
                 backend_simulator.get_sim_trial_by_index(trial.index)
@@ -169,30 +182,74 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult
             # The BackendSimulator distinguishes between queued and running
             # trials "for testing particular initialization cases", but these
             # are all "running" to Scheduler.
-            # start_time = none_throws(sim_trial.sim_queued_time)
             start_time = none_throws(sim_trial.sim_start_time)
             if sim_trial.sim_completed_time is None:  # Still running
                 max_t = backend_simulator.time - start_time
+            elif sim_trial.sim_completed_time > backend_simulator.time:
+                raise RuntimeError(
+                    "The trial's completion time is in the future! This is "
+                    f"unexpected. {sim_trial.sim_completed_time=}, "
+                    f"{backend_simulator.time=}"
+                )
             else:
-                if sim_trial.sim_completed_time > backend_simulator.time:
-                    raise RuntimeError(
-                        "The trial's completion time is in the future! This is "
-                        f"unexpected. {sim_trial.sim_completed_time=}, "
-                        f"{backend_simulator.time=}"
-                    )
-                # Completed, may have stopped early
-                max_t = none_throws(sim_trial.sim_completed_time) - start_time
-
-        df = (
-            metadata.dfs[self.name]
-            .loc[lambda x: x["virtual runtime"] <= max_t]
-            .drop(columns=["virtual runtime"])
-            .reset_index(drop=True)
-            # Just in case the key was renamed by a subclass
-            .rename(columns={"step": self.map_key_info.key})
-        )
+                # Completed, may have stopped early -- can't assume all data available
+                completed_time = none_throws(sim_trial.sim_completed_time)
+                max_t = completed_time - start_time
+
+            available_data = df[df["virtual runtime"] <= max_t]
+
+        # `drop` returns a copy, so assigning to `sem` below does not mutate
+        # the DataFrame stored in the trial's metadata.
+        available_data = available_data.drop(columns=["virtual runtime"])
         if not self.observe_noise_sd:
-            df["sem"] = None
+            available_data["sem"] = None
+        return self._df_to_result(df=available_data)
+
+    @abstractmethod
+    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
+        """
+        Convert a DataFrame of observable data to Data or MapData, as
+        appropriate for the class.
+        """
+        ...
+
+
+class BenchmarkMetric(BenchmarkMetricBase):
+    """
+    Metric for benchmarking that produces `Data` and is not available while
+    running.
+    """
+
+    def _class_specific_metadata_validation(
+        self, metadata: BenchmarkTrialMetadata | None
+    ) -> None:
+        if metadata is not None:
+            df = metadata.dfs[self.name]
+            if df["step"].nunique() > 1:
+                raise ValueError(
+                    "Trial has data from multiple time steps. This is"
+                    f" not supported by `{self.__class__.__name__}`; use "
+                    "`BenchmarkMapMetric`."
+                )
+
+    def _df_to_result(self, df: DataFrame) -> MetricFetchResult:
+        return Ok(value=Data(df=df.drop(columns=["step"])))
+
+
+class BenchmarkMapMetric(MapMetric, BenchmarkMetricBase):
+    """
+    Metric for benchmarking that produces `MapData` and is available while
+    running.
+ """ + + # pyre-fixme: Inconsistent override [15]: `map_key_info` overrides attribute + # defined in `MapMetric` inconsistently. Type `MapKeyInfo[int]` is not a + # subtype of the overridden attribute `MapKeyInfo[float]` + map_key_info: MapKeyInfo[int] = MapKeyInfo(key="step", default_value=0) + + @classmethod + def is_available_while_running(cls) -> bool: + return True + def _df_to_result(self, df: DataFrame) -> MetricFetchResult: + # Just in case the key was renamed by a subclass + df = df.rename(columns={"step": self.map_key_info.key}) return Ok(value=MapData(df=df, map_key_infos=[self.map_key_info])) diff --git a/ax/benchmark/tests/test_benchmark_metric.py b/ax/benchmark/tests/test_benchmark_metric.py index 9ce986312ee..514074eeebf 100644 --- a/ax/benchmark/tests/test_benchmark_metric.py +++ b/ax/benchmark/tests/test_benchmark_metric.py @@ -116,7 +116,7 @@ def test_fetch_trial_data(self) -> None: "benchmark_metadata" ] with self.assertRaisesRegex( - ValueError, "Trial 0 has data from multiple time steps" + ValueError, "Trial has data from multiple time steps" ): self.metric1.fetch_trial_data(trial=trial)