From 666df783effa4024eb32367b2b5da7a31f542d8f Mon Sep 17 00:00:00 2001
From: Paschal Igusti
Date: Thu, 12 Dec 2024 17:24:55 -0800
Subject: [PATCH] Move get_improvement_over_baseline to the BestPointMixin (#3156)

Summary:
As titled - in this change we:
1. Move get_improvement_over_baseline to the BestPointMixin, which now requires an experiment and a generation strategy
2. Make the `baseline_arm_name` parameter optional and fall back to `select_baseline_arm()` when it is not provided
3. Introduce a `best_point_utils.py` file to store `select_baseline_arm()`
4. Modify `select_baseline_arm()` to take a dictionary of arm names to Arms instead of an experiment df

Reviewed By: lena-kashtelyan

Differential Revision: D66472613
---
 ax/service/scheduler.py                   |  65 ---------
 ax/service/tests/scheduler_test_utils.py  |  36 ++---
 ax/service/tests/test_best_point_utils.py |  77 +++++++++-
 ax/service/tests/test_report_utils.py     | 163 ++++++++--------------
 ax/service/utils/best_point_mixin.py      |  72 ++++++++++
 ax/service/utils/best_point_utils.py      |  54 +++++++
 ax/service/utils/report_utils.py          |  49 +------
 sphinx/source/service.rst                 |   6 +
 8 files changed, 291 insertions(+), 231 deletions(-)
 create mode 100644 ax/service/utils/best_point_utils.py

diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 1bbed5705bf..502cd523586 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -1021,71 +1021,6 @@ def summarize_final_result(self) -> OptimizationResult:
         """
         return OptimizationResult()
 
-    def get_improvement_over_baseline(
-        self,
-        baseline_arm_name: str | None = None,
-    ) -> float:
-        """Returns the scalarized improvement over baseline, if applicable.
-
-        Returns:
-            For Single Objective cases, returns % improvement of objective.
-            Positive indicates improvement over baseline. Negative indicates regression.
-            For Multi Objective cases, throws NotImplementedError
-        """
-        if self.experiment.is_moo_problem:
-            raise NotImplementedError(
-                "`get_improvement_over_baseline` not yet implemented"
-                + " for multi-objective problems."
-            )
-        if not baseline_arm_name:
-            raise UserInputError(
-                "`get_improvement_over_baseline` missing required parameter: "
-                + f"{baseline_arm_name=}, "
-            )
-
-        optimization_config = self.experiment.optimization_config
-        if not optimization_config:
-            raise ValueError("No optimization config found.")
-
-        objective_metric_name = optimization_config.objective.metric.name
-
-        # get the baseline trial
-        data = self.experiment.lookup_data().df
-        data = data[data["arm_name"] == baseline_arm_name]
-        if len(data) == 0:
-            raise UserInputError(
-                "`get_improvement_over_baseline`"
-                " could not find baseline arm"
-                f" `{baseline_arm_name}` in the experiment data."
- ) - data = data[data["metric_name"] == objective_metric_name] - baseline_value = data.iloc[0]["mean"] - - # Find objective value of the best trial - idx, param, best_arm = none_throws( - self.get_best_trial( - optimization_config=optimization_config, use_model_predictions=False - ) - ) - best_arm = none_throws(best_arm) - best_obj_value = best_arm[0][objective_metric_name] - - def percent_change(x: float, y: float, minimize: bool) -> float: - if x == 0: - raise ZeroDivisionError( - "Cannot compute percent improvement when denom is zero" - ) - percent_change = (y - x) / abs(x) * 100 - if minimize: - percent_change = -percent_change - return percent_change - - return percent_change( - x=baseline_value, - y=best_obj_value, - minimize=optimization_config.objective.minimize, - ) - def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool: """Checks if the failure rate (set in scheduler options) has been exceeded at any point during the optimization. diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index b9eb7e0be9e..99edccfdeda 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -2197,6 +2197,8 @@ def test_get_improvement_over_baseline(self) -> None: scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0] ) percent_improvement = scheduler.get_improvement_over_baseline( + experiment=scheduler.experiment, + generation_strategy=scheduler.standard_generation_strategy, baseline_arm_name=first_trial_name, ) @@ -2209,11 +2211,7 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None: self.branin_experiment.optimization_config = ( get_branin_multi_objective_optimization_config() ) - - gs = self._get_generation_strategy_strategy_for_test( - experiment=self.branin_experiment, - generation_strategy=self.sobol_MBM_GS, - ) + gs = self.sobol_MBM_GS scheduler = Scheduler( experiment=self.branin_experiment, @@ -2227,6 +2225,8 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None: with self.assertRaises(NotImplementedError): scheduler.get_improvement_over_baseline( + experiment=scheduler.experiment, + generation_strategy=scheduler.standard_generation_strategy, baseline_arm_name=None, ) @@ -2236,10 +2236,7 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None experiment.name = f"{self.branin_experiment.name}_but_moo" experiment.runner = self.runner - gs = self._get_generation_strategy_strategy_for_test( - experiment=experiment, - generation_strategy=self.two_sobol_steps_GS, - ) + gs = self.two_sobol_steps_GS scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. 
generation_strategy=gs, @@ -2251,8 +2248,10 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None db_settings=self.db_settings_if_always_needed, ) - with self.assertRaises(UserInputError): + with self.assertRaises(ValueError): scheduler.get_improvement_over_baseline( + experiment=scheduler.experiment, + generation_strategy=scheduler.standard_generation_strategy, baseline_arm_name=None, ) @@ -2267,19 +2266,20 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None scheduler.experiment = exp_copy with self.assertRaises(ValueError): - scheduler.get_improvement_over_baseline(baseline_arm_name="baseline") + scheduler.get_improvement_over_baseline( + experiment=scheduler.experiment, + generation_strategy=scheduler.standard_generation_strategy, + baseline_arm_name="baseline", + ) def test_get_improvement_over_baseline_no_baseline(self) -> None: """Test that get_improvement_over_baseline returns UserInputError when baseline is not found in data.""" n_total_trials = 8 - gs = self._get_generation_strategy_strategy_for_test( - experiment=self.branin_experiment, - generation_strategy=self.two_sobol_steps_GS, - ) - + experiment = self.branin_experiment + gs = self.two_sobol_steps_GS scheduler = Scheduler( - experiment=self.branin_experiment, # Has runner and metrics. + experiment=experiment, # Has runner and metrics. generation_strategy=gs, options=SchedulerOptions( total_trials=n_total_trials, @@ -2293,6 +2293,8 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None: with self.assertRaises(UserInputError): scheduler.get_improvement_over_baseline( + experiment=experiment, + generation_strategy=gs, baseline_arm_name="baseline_arm_not_in_data", ) diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index 9672db9dcff..dae1ffe248d 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -6,15 +6,18 @@ # pyre-strict +import copy import random -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, PropertyMock import pandas as pd import torch from ax.core.arm import Arm from ax.core.batch_trial import BatchTrial from ax.core.data import Data +from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun +from ax.core.metric import Metric from ax.core.objective import ScalarizedObjective from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import OutcomeConstraint @@ -32,10 +35,12 @@ get_best_raw_objective_point, logger as best_point_logger, ) +from ax.service.utils.best_point_utils import select_baseline_name_default_first_trial from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, get_branin_metric, + get_branin_search_space, get_experiment_with_observations, get_sobol, ) @@ -556,6 +561,76 @@ def test_is_row_feasible(self) -> None: df.index, feasible_series.index, check_names=False ) + def test_compare_to_baseline_select_baseline_name_default_first_trial(self) -> None: + OBJECTIVE_METRIC = "objective" + true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True) + experiment = Experiment( + search_space=get_branin_search_space(), + tracking_metrics=[true_obj_metric], + ) + + with patch.object( + Experiment, "arms_by_name", new_callable=PropertyMock + ) as mock_arms_by_name: + mock_arms_by_name.return_value = {"arm1": "value1", "arm2": "value2"} + self.assertEqual( + 
select_baseline_name_default_first_trial( + experiment=experiment, + baseline_arm_name="arm1", + ), + ("arm1", False), + ) + + # specified baseline arm not in trial + wrong_baseline_name = "wrong_baseline_name" + with self.assertRaisesRegex( + ValueError, + "Arm by name .*" + " not found.", + ): + select_baseline_name_default_first_trial( + experiment=experiment, + baseline_arm_name=wrong_baseline_name, + ) + + # status quo baseline arm + experiment_with_status_quo = copy.deepcopy(experiment) + experiment_with_status_quo.status_quo = Arm( + name="status_quo", + parameters={"x1": 0, "x2": 0}, + ) + self.assertEqual( + select_baseline_name_default_first_trial( + experiment=experiment_with_status_quo, + baseline_arm_name=None, + ), + ("status_quo", False), + ) + # first arm from trials + custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2}) + experiment.new_trial().add_arm(custom_arm) + self.assertEqual( + select_baseline_name_default_first_trial( + experiment=experiment, + baseline_arm_name=None, + ), + ("m_0", True), + ) + + # none selected + experiment_with_no_valid_baseline = Experiment( + search_space=get_branin_search_space(), + tracking_metrics=[true_obj_metric], + ) + + with self.assertRaisesRegex( + ValueError, + "Could not find valid baseline arm.", + ): + select_baseline_name_default_first_trial( + experiment=experiment_with_no_valid_baseline, + baseline_arm_name=None, + ) + def _repeat_elements(list_to_replicate: list[bool], n_repeats: int) -> pd.Series: return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)]) diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index 653693166b2..a69f2396fab 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -6,12 +6,11 @@ # pyre-strict -import copy import itertools from collections import namedtuple from logging import INFO, WARN from unittest import mock -from unittest.mock import patch +from unittest.mock import patch, PropertyMock import pandas as pd from ax.core.arm import Arm @@ -44,7 +43,6 @@ FEASIBLE_COL_NAME, get_standard_plots, plot_feature_importance_by_feature_plotly, - select_baseline_arm, warn_if_unpredictable_metrics, ) from ax.service.utils.scheduler_options import SchedulerOptions @@ -612,9 +610,21 @@ def test_compare_to_baseline(self) -> None: ] arms_df = pd.DataFrame(data) + arms_by_name_mock = { + BASELINE_ARM_NAME: Arm(name=BASELINE_ARM_NAME, parameters={}), + "dummy": Arm(name="dummy", parameters={}), + "optimal": Arm(name="optimal", parameters={}), + "bad_optimal": Arm(name="bad_optimal", parameters={}), + } + with patch( "ax.service.utils.report_utils.exp_to_df", return_value=arms_df, + ), patch.object( + Experiment, + "arms_by_name", + new_callable=PropertyMock, + return_value=arms_by_name_mock, ): true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=False) experiment = Experiment( @@ -677,9 +687,21 @@ def test_compare_to_baseline_pass_in_opt(self) -> None: ] arms_df = pd.DataFrame(data) + arms_by_name_mock = { + BASELINE_ARM_NAME: Arm(name=BASELINE_ARM_NAME, parameters={}), + "dummy": Arm(name="dummy", parameters={}), + "optimal": Arm(name="optimal", parameters={}), + "bad_optimal": Arm(name="bad_optimal", parameters={}), + } + with patch( "ax.service.utils.report_utils.exp_to_df", return_value=arms_df, + ), patch.object( + Experiment, + "arms_by_name", + new_callable=PropertyMock, + return_value=arms_by_name_mock, ): true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=False) 
experiment = Experiment( @@ -740,9 +762,21 @@ def test_compare_to_baseline_minimize(self) -> None: ] arms_df = pd.DataFrame(data) + arms_by_name_mock = { + custom_baseline_arm_name: Arm(name=custom_baseline_arm_name, parameters={}), + "dummy": Arm(name="dummy", parameters={}), + "optimal": Arm(name="optimal", parameters={}), + "bad_optimal": Arm(name="bad_optimal", parameters={}), + } + with patch( "ax.service.utils.report_utils.exp_to_df", return_value=arms_df, + ), patch.object( + Experiment, + "arms_by_name", + new_callable=PropertyMock, + return_value=arms_by_name_mock, ): true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True) experiment = Experiment( @@ -816,10 +850,21 @@ def test_compare_to_baseline_edge_case(self) -> None: {"trial_index": 1, "arm_name": "optimal", OBJECTIVE_METRIC: 1.0}, ] arms_df = pd.DataFrame(data) + arms_by_name_mock = { + BASELINE_ARM_NAME: Arm(name=BASELINE_ARM_NAME, parameters={}), + "dummy": Arm(name="dummy", parameters={}), + "optimal": Arm(name="optimal", parameters={}), + "bad_optimal": Arm(name="bad_optimal", parameters={}), + } with patch( "ax.service.utils.report_utils.exp_to_df", return_value=arms_df, + ), patch.object( + Experiment, + "arms_by_name", + new_callable=PropertyMock, + return_value=arms_by_name_mock, ): with self.assertLogs("ax", level=INFO) as log: self.assertEqual( @@ -998,11 +1043,7 @@ def test_compare_to_baseline_arms_not_found(self) -> None: ) self.assertTrue( any( - ( - f"compare_to_baseline: baseline row: {baseline_arm_name=}" - " not found in arms" - ) - in log_str + (f"Arm by name {baseline_arm_name=} not found.") in log_str for log_str in log.output ), log.output, @@ -1051,9 +1092,21 @@ def test_compare_to_baseline_moo(self) -> None: ] arms_df = pd.DataFrame(data) + arms_by_name_mock = { + BASELINE_ARM_NAME: Arm(name=BASELINE_ARM_NAME, parameters={}), + "dummy": Arm(name="dummy", parameters={}), + "optimal": Arm(name="optimal", parameters={}), + "bad_optimal": Arm(name="bad_optimal", parameters={}), + } + with patch( "ax.service.utils.report_utils.exp_to_df", return_value=arms_df, + ), patch.object( + Experiment, + "arms_by_name", + new_callable=PropertyMock, + return_value=arms_by_name_mock, ): m0 = Metric(name="m0", lower_is_better=False) m1 = Metric(name="m1", lower_is_better=True) @@ -1176,100 +1229,6 @@ def test_compare_to_baseline_equal(self) -> None: self.assertIsNone(result) - def test_compare_to_baseline_select_baseline_arm(self) -> None: - OBJECTIVE_METRIC = "objective" - true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True) - experiment = Experiment( - search_space=get_branin_search_space(), - tracking_metrics=[true_obj_metric], - ) - - # specified baseline - data = [ - { - "trial_index": 0, - "arm_name": "m_0", - OBJECTIVE_METRIC: 0.2, - }, - { - "trial_index": 1, - "arm_name": BASELINE_ARM_NAME, - OBJECTIVE_METRIC: 0.2, - }, - { - "trial_index": 2, - "arm_name": "status_quo", - OBJECTIVE_METRIC: 0.2, - }, - ] - arms_df = pd.DataFrame(data) - self.assertEqual( - select_baseline_arm( - experiment=experiment, - arms_df=arms_df, - baseline_arm_name=BASELINE_ARM_NAME, - ), - (BASELINE_ARM_NAME, False), - ) - - # specified baseline arm not in trial - wrong_baseline_name = "wrong_baseline_name" - with self.assertRaisesRegex( - ValueError, - "compare_to_baseline: baseline row: .*" + " not found in arms", - ): - select_baseline_arm( - experiment=experiment, - arms_df=arms_df, - baseline_arm_name=wrong_baseline_name, - ) - - # status quo baseline arm - experiment_with_status_quo = 
copy.deepcopy(experiment) - experiment_with_status_quo.status_quo = Arm( - name="status_quo", - parameters={"x1": 0, "x2": 0}, - ) - self.assertEqual( - select_baseline_arm( - experiment=experiment_with_status_quo, - arms_df=arms_df, - baseline_arm_name=None, - ), - ("status_quo", False), - ) - # first arm from trials - custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2}) - experiment.new_trial().add_arm(custom_arm) - self.assertEqual( - select_baseline_arm( - experiment=experiment, - arms_df=arms_df, - baseline_arm_name=None, - ), - ("m_0", True), - ) - - # none selected - experiment_with_no_valid_baseline = Experiment( - search_space=get_branin_search_space(), - tracking_metrics=[true_obj_metric], - ) - experiment_with_no_valid_baseline.status_quo = Arm( - name="not found", - parameters={"x1": 0, "x2": 0}, - ) - custom_arm = Arm(name="also not found", parameters={"x1": 0.1, "x2": 0.2}) - experiment_with_no_valid_baseline.new_trial().add_arm(custom_arm) - with self.assertRaisesRegex( - ValueError, "compare_to_baseline: could not find valid baseline arm" - ): - select_baseline_arm( - experiment=experiment_with_no_valid_baseline, - arms_df=arms_df, - baseline_arm_name=None, - ) - def test_warn_if_unpredictable_metrics(self) -> None: expected_msg = ( "The following metric(s) are behaving unpredictably and may be noisy or " diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index 5996bac164c..125b7f728ca 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -21,6 +21,7 @@ OptimizationConfig, ) from ax.core.types import TModelPredictArm, TParameterization +from ax.exceptions.core import UserInputError from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import ( extract_objective_thresholds, @@ -43,6 +44,7 @@ extract_Y_from_data, fill_missing_thresholds_from_nadir, ) +from ax.service.utils.best_point_utils import select_baseline_name_default_first_trial from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import checked_cast from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning @@ -637,3 +639,73 @@ def _get_trace_by_progression( # pyre-fixme[16]: Item `List` of `Union[List[float], ndarray[typing.Any, # np.dtype[typing.Any]]]` has no attribute `squeeze`. return best_observed.tolist(), bins.squeeze(axis=0).tolist() + + def get_improvement_over_baseline( + self, + experiment: Experiment, + generation_strategy: GenerationStrategy, + baseline_arm_name: str | None = None, + ) -> float: + """Returns the scalarized improvement over baseline, if applicable. + + Returns: + For Single Objective cases, returns % improvement of objective. + Positive indicates improvement over baseline. Negative indicates regression. + For Multi Objective cases, throws NotImplementedError + """ + if experiment.is_moo_problem: + raise NotImplementedError( + "`get_improvement_over_baseline` not yet implemented" + + " for multi-objective problems." 
+ ) + if not baseline_arm_name: + baseline_arm_name, _ = select_baseline_name_default_first_trial( + experiment=experiment, + baseline_arm_name=baseline_arm_name, + ) + + optimization_config = experiment.optimization_config + if not optimization_config: + raise ValueError("No optimization config found.") + + objective_metric_name = optimization_config.objective.metric.name + + # get the baseline trial + data = experiment.lookup_data().df + data = data[data["arm_name"] == baseline_arm_name] + if len(data) == 0: + raise UserInputError( + "`get_improvement_over_baseline`" + " could not find baseline arm" + f" `{baseline_arm_name}` in the experiment data." + ) + data = data[data["metric_name"] == objective_metric_name] + baseline_value = data.iloc[0]["mean"] + + # Find objective value of the best trial + idx, param, best_arm = none_throws( + self._get_best_trial( + experiment=experiment, + generation_strategy=generation_strategy, + optimization_config=optimization_config, + use_model_predictions=False, + ) + ) + best_arm = none_throws(best_arm) + best_obj_value = best_arm[0][objective_metric_name] + + def percent_change(x: float, y: float, minimize: bool) -> float: + if x == 0: + raise ZeroDivisionError( + "Cannot compute percent improvement when denom is zero" + ) + percent_change = (y - x) / abs(x) * 100 + if minimize: + percent_change = -percent_change + return percent_change + + return percent_change( + x=baseline_value, + y=best_obj_value, + minimize=optimization_config.objective.minimize, + ) diff --git a/ax/service/utils/best_point_utils.py b/ax/service/utils/best_point_utils.py new file mode 100644 index 00000000000..1bda8dd3219 --- /dev/null +++ b/ax/service/utils/best_point_utils.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.core.experiment import Experiment +from pyre_extensions import none_throws + +BASELINE_ARM_NAME = "baseline_arm" + + +def select_baseline_name_default_first_trial( + experiment: Experiment, baseline_arm_name: str | None +) -> tuple[str, bool]: + """ + Choose a baseline arm from arms on the experiment. Logic: + 1. If ``baseline_arm_name`` provided, validate that arm exists + and return that arm name. + 2. If ``experiment.status_quo`` is set, return its arm name. + 3. If there is at least one trial on the experiment, use the + first trial's first arm as the baseline. + 4. Error if 1-3 all don't apply. 
+ + Returns: + Tuple: + baseline arm name (str) + true when baseline selected from first arm of experiment (bool) + raise ValueError if no valid baseline found + """ + + arms_dict = experiment.arms_by_name + + if baseline_arm_name: + if baseline_arm_name not in arms_dict: + raise ValueError(f"Arm by name {baseline_arm_name=} not found.") + return baseline_arm_name, False + + if experiment.status_quo and none_throws(experiment.status_quo).name in arms_dict: + baseline_arm_name = none_throws(experiment.status_quo).name + return baseline_arm_name, False + + if ( + experiment.trials + and experiment.trials[0].arms + and experiment.trials[0].arms[0].name in arms_dict + ): + baseline_arm_name = experiment.trials[0].arms[0].name + return baseline_arm_name, True + + else: + raise ValueError("Could not find valid baseline arm.") diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index b4bf34dd790..534fe286e32 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -61,6 +61,7 @@ plot_objective_value_vs_trial_index, ) from ax.service.utils.best_point import _derel_opt_config_wrapper, _is_row_feasible +from ax.service.utils.best_point_utils import select_baseline_name_default_first_trial from ax.service.utils.early_stopping import get_early_stopping_metrics from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import checked_cast @@ -1331,50 +1332,6 @@ def _build_result_tuple( return result -def select_baseline_arm( - experiment: Experiment, arms_df: pd.DataFrame, baseline_arm_name: str | None -) -> tuple[str, bool]: - """ - Choose a baseline arm that is found in arms_df - - Returns: - Tuple: - baseline_arm_name if valid baseline exists - true when baseline selected from first arm of sweep - raise ValueError if no valid baseline found - """ - - if baseline_arm_name: - if arms_df[arms_df["arm_name"] == baseline_arm_name].empty: - raise ValueError( - f"compare_to_baseline: baseline row: {baseline_arm_name=}" - " not found in arms" - ) - return baseline_arm_name, False - - else: - if ( - experiment.status_quo - and not arms_df[ - arms_df["arm_name"] == none_throws(experiment.status_quo).name - ].empty - ): - baseline_arm_name = none_throws(experiment.status_quo).name - return baseline_arm_name, False - - if ( - experiment.trials - and experiment.trials[0].arms - and not arms_df[ - arms_df["arm_name"] == experiment.trials[0].arms[0].name - ].empty - ): - baseline_arm_name = experiment.trials[0].arms[0].name - return baseline_arm_name, True - else: - raise ValueError("compare_to_baseline: could not find valid baseline arm") - - def maybe_extract_baseline_comparison_values( experiment: Experiment, optimization_config: OptimizationConfig | None, @@ -1425,8 +1382,8 @@ def maybe_extract_baseline_comparison_values( return None try: - baseline_arm_name, _ = select_baseline_arm( - experiment=experiment, arms_df=arms_df, baseline_arm_name=baseline_arm_name + baseline_arm_name, _ = select_baseline_name_default_first_trial( + experiment=experiment, baseline_arm_name=baseline_arm_name ) except Exception as e: logger.info(f"compare_to_baseline: could not select baseline arm. Reason: {e}") diff --git a/sphinx/source/service.rst b/sphinx/source/service.rst index e66454a97bb..70957ae263e 100644 --- a/sphinx/source/service.rst +++ b/sphinx/source/service.rst @@ -73,6 +73,12 @@ Best Point Identification :show-inheritance: +.. 
automodule:: ax.service.utils.best_point_utils + :members: + :undoc-members: + :show-inheritance: + + Instantiation ~~~~~~~~~~~~~
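
Usage note (illustrative sketch, not part of the diff): with this change,
`get_improvement_over_baseline` lives on `BestPointMixin` and takes the
experiment and generation strategy explicitly, as the updated scheduler tests
above show. A minimal sketch of the new call pattern follows; `experiment` and
`generation_strategy` are placeholders for objects the caller already has, and
the `SchedulerOptions` values are arbitrary.

    from ax.service.scheduler import Scheduler
    from ax.service.utils.scheduler_options import SchedulerOptions

    # Assumes `experiment` already has a runner, metrics, and a
    # single-objective optimization config attached.
    scheduler = Scheduler(
        experiment=experiment,
        generation_strategy=generation_strategy,
        options=SchedulerOptions(total_trials=8),
    )
    scheduler.run_all_trials()

    # Experiment and generation strategy are now explicit arguments. When
    # `baseline_arm_name` is omitted, the baseline is chosen by
    # `select_baseline_name_default_first_trial` (status quo first, then the
    # first trial's first arm).
    percent_improvement = scheduler.get_improvement_over_baseline(
        experiment=scheduler.experiment,
        generation_strategy=scheduler.standard_generation_strategy,
        baseline_arm_name=None,
    )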
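
Similarly, a short sketch of the baseline-selection helper introduced in
`best_point_utils.py`; the `experiment` below is again a placeholder and only
needs a status quo arm or at least one trial with arms for a baseline to be
found.

    from ax.service.utils.best_point_utils import (
        select_baseline_name_default_first_trial,
    )

    # Prefers the experiment's status quo arm, then falls back to the first
    # trial's first arm; raises ValueError if neither exists or if an
    # explicitly passed `baseline_arm_name` is not an arm on the experiment.
    baseline_name, is_first_arm = select_baseline_name_default_first_trial(
        experiment=experiment,
        baseline_arm_name=None,
    )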