Move get_improvement_over_baseline to the BestPointMixin (facebook#3156)
Summary: Pull Request resolved: facebook#3156

Differential Revision: D66472613
paschai authored and facebook-github-bot committed Dec 9, 2024
1 parent ec08fe5 commit 9ba35e2
Showing 18 changed files with 685 additions and 638 deletions.
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_torch_moo_modelbridge.py
@@ -37,7 +37,7 @@
infer_objective_thresholds,
pareto_frontier_evaluator,
)
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.best_point_utils import exp_to_df
from ax.utils.common.random import set_rng_seed
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
3 changes: 2 additions & 1 deletion ax/plot/parallel_coordinates.py
@@ -10,7 +10,8 @@
import pandas as pd
from ax.core.experiment import Experiment
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.service.utils.report_utils import _get_shortest_unique_suffix_dict, exp_to_df
from ax.service.utils.best_point_utils import exp_to_df
from ax.service.utils.report_utils import _get_shortest_unique_suffix_dict
from plotly import express as px, graph_objs as go


2 changes: 1 addition & 1 deletion ax/plot/tests/test_traces.py
@@ -15,7 +15,7 @@
optimization_trace_single_method_plotly,
plot_objective_value_vs_trial_index,
)
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.best_point_utils import exp_to_df
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import mock_botorch_optimize
2 changes: 1 addition & 1 deletion ax/service/ax_client.py
@@ -68,12 +68,12 @@
from ax.plot.trace import optimization_trace_single_method
from ax.service.utils.analysis_base import AnalysisBase
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.best_point_utils import exp_to_df
from ax.service.utils.instantiation import (
FixedFeatures,
InstantiationBase,
ObjectiveProperties,
)
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.with_db_settings_base import DBSettings
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
65 changes: 0 additions & 65 deletions ax/service/scheduler.py
@@ -1021,71 +1021,6 @@ def summarize_final_result(self) -> OptimizationResult:
"""
return OptimizationResult()

def get_improvement_over_baseline(
self,
baseline_arm_name: str | None = None,
) -> float:
"""Returns the scalarized improvement over baseline, if applicable.
Returns:
For Single Objective cases, returns % improvement of objective.
Positive indicates improvement over baseline. Negative indicates regression.
For Multi Objective cases, throws NotImplementedError
"""
if self.experiment.is_moo_problem:
raise NotImplementedError(
"`get_improvement_over_baseline` not yet implemented"
+ " for multi-objective problems."
)
if not baseline_arm_name:
raise UserInputError(
"`get_improvement_over_baseline` missing required parameter: "
+ f"{baseline_arm_name=}, "
)

optimization_config = self.experiment.optimization_config
if not optimization_config:
raise ValueError("No optimization config found.")

objective_metric_name = optimization_config.objective.metric.name

# get the baseline trial
data = self.experiment.lookup_data().df
data = data[data["arm_name"] == baseline_arm_name]
if len(data) == 0:
raise UserInputError(
"`get_improvement_over_baseline`"
" could not find baseline arm"
f" `{baseline_arm_name}` in the experiment data."
)
data = data[data["metric_name"] == objective_metric_name]
baseline_value = data.iloc[0]["mean"]

# Find objective value of the best trial
idx, param, best_arm = none_throws(
self.get_best_trial(
optimization_config=optimization_config, use_model_predictions=False
)
)
best_arm = none_throws(best_arm)
best_obj_value = best_arm[0][objective_metric_name]

def percent_change(x: float, y: float, minimize: bool) -> float:
if x == 0:
raise ZeroDivisionError(
"Cannot compute percent improvement when denom is zero"
)
percent_change = (y - x) / abs(x) * 100
if minimize:
percent_change = -percent_change
return percent_change

return percent_change(
x=baseline_value,
y=best_obj_value,
minimize=optimization_config.objective.minimize,
)

def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
"""Checks if the failure rate (set in scheduler options) has been exceeded at
any point during the optimization.
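For reference, the percent-change rule in the method removed above (relocated, per this commit, to BestPointMixin) behaves as in this small worked example. Illustration only, not part of the diff:

    # Worked example of the removed percent_change helper: minimization objective,
    # baseline mean x = 10.0, best observed mean y = 8.0.
    x, y = 10.0, 8.0
    change = (y - x) / abs(x) * 100   # -20.0
    change = -change                  # minimize=True flips the sign -> +20.0
    # Reported as a 20% improvement of the objective over the baseline arm.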
36 changes: 19 additions & 17 deletions ax/service/tests/scheduler_test_utils.py
@@ -2197,6 +2197,8 @@ def test_get_improvement_over_baseline(self) -> None:
scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0]
)
percent_improvement = scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=first_trial_name,
)

@@ -2209,11 +2211,7 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:
self.branin_experiment.optimization_config = (
get_branin_multi_objective_optimization_config()
)

gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.sobol_MBM_GS,
)
gs = self.sobol_MBM_GS

scheduler = Scheduler(
experiment=self.branin_experiment,
@@ -2227,6 +2225,8 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:

with self.assertRaises(NotImplementedError):
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=None,
)

@@ -2236,10 +2236,7 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
experiment.name = f"{self.branin_experiment.name}_but_moo"
experiment.runner = self.runner

gs = self._get_generation_strategy_strategy_for_test(
experiment=experiment,
generation_strategy=self.two_sobol_steps_GS,
)
gs = self.two_sobol_steps_GS
scheduler = Scheduler(
experiment=self.branin_experiment, # Has runner and metrics.
generation_strategy=gs,
@@ -2251,8 +2248,10 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
db_settings=self.db_settings_if_always_needed,
)

with self.assertRaises(UserInputError):
with self.assertRaises(ValueError):
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=None,
)

@@ -2267,19 +2266,20 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
scheduler.experiment = exp_copy

with self.assertRaises(ValueError):
scheduler.get_improvement_over_baseline(baseline_arm_name="baseline")
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name="baseline",
)

def test_get_improvement_over_baseline_no_baseline(self) -> None:
"""Test that get_improvement_over_baseline returns UserInputError when
baseline is not found in data."""
n_total_trials = 8
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)

experiment = self.branin_experiment
gs = self.two_sobol_steps_GS
scheduler = Scheduler(
experiment=self.branin_experiment, # Has runner and metrics.
experiment=experiment, # Has runner and metrics.
generation_strategy=gs,
options=SchedulerOptions(
total_trials=n_total_trials,
@@ -2293,6 +2293,8 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None:

with self.assertRaises(UserInputError):
scheduler.get_improvement_over_baseline(
experiment=experiment,
generation_strategy=gs,
baseline_arm_name="baseline_arm_not_in_data",
)

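The test updates above show the call-site change: callers now pass the experiment and generation strategy explicitly rather than having the Scheduler read them off itself. A hedged usage sketch, with argument names taken from the updated tests (the exact signature on BestPointMixin may differ):

    # Usage sketch based on the updated tests above; not a verbatim API reference.
    percent_improvement = scheduler.get_improvement_over_baseline(
        experiment=scheduler.experiment,
        generation_strategy=scheduler.standard_generation_strategy,
        baseline_arm_name=first_trial_name,  # arm to treat as the baseline
    )
    # Positive values indicate % improvement of the objective over the baseline arm;
    # multi-objective experiments still raise NotImplementedError.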
100 changes: 100 additions & 0 deletions ax/service/tests/test_best_point_utils.py
@@ -6,6 +6,7 @@

# pyre-strict

import copy
import random
from unittest.mock import MagicMock, patch

@@ -14,7 +15,9 @@
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
@@ -32,10 +35,13 @@
get_best_raw_objective_point,
logger as best_point_logger,
)
from ax.service.utils.best_point_utils import select_baseline_arm
from ax.service.utils.report_utils import BASELINE_ARM_NAME
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_metric,
get_branin_search_space,
get_experiment_with_observations,
get_sobol,
)
@@ -556,6 +562,100 @@ def test_is_row_feasible(self) -> None:
df.index, feasible_series.index, check_names=False
)

def test_compare_to_baseline_select_baseline_arm(self) -> None:
OBJECTIVE_METRIC = "objective"
true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True)
experiment = Experiment(
search_space=get_branin_search_space(),
tracking_metrics=[true_obj_metric],
)

# specified baseline
data = [
{
"trial_index": 0,
"arm_name": "m_0",
OBJECTIVE_METRIC: 0.2,
},
{
"trial_index": 1,
"arm_name": BASELINE_ARM_NAME,
OBJECTIVE_METRIC: 0.2,
},
{
"trial_index": 2,
"arm_name": "status_quo",
OBJECTIVE_METRIC: 0.2,
},
]
arms_df = pd.DataFrame(data)
self.assertEqual(
select_baseline_arm(
experiment=experiment,
arms_df=arms_df,
baseline_arm_name=BASELINE_ARM_NAME,
),
(BASELINE_ARM_NAME, False),
)

# specified baseline arm not in trial
wrong_baseline_name = "wrong_baseline_name"
with self.assertRaisesRegex(
ValueError,
"compare_to_baseline: baseline row: .*" + " not found in arms",
):
select_baseline_arm(
experiment=experiment,
arms_df=arms_df,
baseline_arm_name=wrong_baseline_name,
)

# status quo baseline arm
experiment_with_status_quo = copy.deepcopy(experiment)
experiment_with_status_quo.status_quo = Arm(
name="status_quo",
parameters={"x1": 0, "x2": 0},
)
self.assertEqual(
select_baseline_arm(
experiment=experiment_with_status_quo,
arms_df=arms_df,
baseline_arm_name=None,
),
("status_quo", False),
)
# first arm from trials
custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2})
experiment.new_trial().add_arm(custom_arm)
self.assertEqual(
select_baseline_arm(
experiment=experiment,
arms_df=arms_df,
baseline_arm_name=None,
),
("m_0", True),
)

# none selected
experiment_with_no_valid_baseline = Experiment(
search_space=get_branin_search_space(),
tracking_metrics=[true_obj_metric],
)
experiment_with_no_valid_baseline.status_quo = Arm(
name="not found",
parameters={"x1": 0, "x2": 0},
)
custom_arm = Arm(name="also not found", parameters={"x1": 0.1, "x2": 0.2})
experiment_with_no_valid_baseline.new_trial().add_arm(custom_arm)
with self.assertRaisesRegex(
ValueError, "compare_to_baseline: could not find valid baseline arm"
):
select_baseline_arm(
experiment=experiment_with_no_valid_baseline,
arms_df=arms_df,
baseline_arm_name=None,
)


def _repeat_elements(list_to_replicate: list[bool], n_repeats: int) -> pd.Series:
return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)])
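The new test above pins down the selection order that `select_baseline_arm` appears to follow: an explicitly named baseline arm, else the experiment's status quo, else the first arm attached to a trial, with a `ValueError` when none of these is present in the arms DataFrame. A minimal sketch of that contract, inferred from the test expectations only (the real implementation lives in ax/service/utils/best_point_utils.py and may differ):

    # Sketch only: the returned tuple mirrors the tests' (arm_name, inferred_from_trials) pairs.
    def select_baseline_arm_sketch(experiment, arms_df, baseline_arm_name):
        arm_names = set(arms_df["arm_name"])
        if baseline_arm_name is not None:
            if baseline_arm_name not in arm_names:
                raise ValueError(
                    f"compare_to_baseline: baseline row: {baseline_arm_name} not found in arms"
                )
            return baseline_arm_name, False
        if experiment.status_quo is not None and experiment.status_quo.name in arm_names:
            return experiment.status_quo.name, False
        for trial in experiment.trials.values():
            for arm in trial.arms:
                if arm.name in arm_names:
                    # Baseline had to be inferred from the first matching trial arm.
                    return arm.name, True
        raise ValueError("compare_to_baseline: could not find valid baseline arm")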