Move get_improvement_over_baseline to the BestPointMixin #3156

Closed · wants to merge 1 commit
ax/service/scheduler.py: 0 additions & 65 deletions

@@ -1021,71 +1021,6 @@ def summarize_final_result(self) -> OptimizationResult:
"""
return OptimizationResult()

def get_improvement_over_baseline(
self,
baseline_arm_name: str | None = None,
) -> float:
"""Returns the scalarized improvement over baseline, if applicable.

Returns:
For Single Objective cases, returns % improvement of objective.
Positive indicates improvement over baseline. Negative indicates regression.
For Multi Objective cases, throws NotImplementedError
"""
if self.experiment.is_moo_problem:
raise NotImplementedError(
"`get_improvement_over_baseline` not yet implemented"
+ " for multi-objective problems."
)
if not baseline_arm_name:
raise UserInputError(
"`get_improvement_over_baseline` missing required parameter: "
+ f"{baseline_arm_name=}, "
)

optimization_config = self.experiment.optimization_config
if not optimization_config:
raise ValueError("No optimization config found.")

objective_metric_name = optimization_config.objective.metric.name

# get the baseline trial
data = self.experiment.lookup_data().df
data = data[data["arm_name"] == baseline_arm_name]
if len(data) == 0:
raise UserInputError(
"`get_improvement_over_baseline`"
" could not find baseline arm"
f" `{baseline_arm_name}` in the experiment data."
)
data = data[data["metric_name"] == objective_metric_name]
baseline_value = data.iloc[0]["mean"]

# Find objective value of the best trial
idx, param, best_arm = none_throws(
self.get_best_trial(
optimization_config=optimization_config, use_model_predictions=False
)
)
best_arm = none_throws(best_arm)
best_obj_value = best_arm[0][objective_metric_name]

def percent_change(x: float, y: float, minimize: bool) -> float:
if x == 0:
raise ZeroDivisionError(
"Cannot compute percent improvement when denom is zero"
)
percent_change = (y - x) / abs(x) * 100
if minimize:
percent_change = -percent_change
return percent_change

return percent_change(
x=baseline_value,
y=best_obj_value,
minimize=optimization_config.objective.minimize,
)

def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
"""Checks if the failure rate (set in scheduler options) has been exceeded at
any point during the optimization.
Expand Down
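
For readers skimming the removal above: the deleted helper scores the best arm against a named baseline arm using a signed percent change. A minimal standalone restatement of that convention, for illustration only and not part of the diff:

# Standalone sketch of the percent-change convention from the removed
# method above; illustration only, not part of this diff.
def percent_change(x: float, y: float, minimize: bool) -> float:
    if x == 0:
        raise ZeroDivisionError("Cannot compute percent improvement when denom is zero")
    change = (y - x) / abs(x) * 100
    # Flip the sign for minimization problems so that a positive result
    # always means "improved over baseline".
    return -change if minimize else change

# With a minimized objective, going from a baseline of 10.0 to a best
# value of 8.0 is a 20% improvement:
assert percent_change(x=10.0, y=8.0, minimize=True) == 20.0
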
ax/service/tests/scheduler_test_utils.py: 19 additions & 17 deletions

@@ -2197,6 +2197,8 @@ def test_get_improvement_over_baseline(self) -> None:
            scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0]
        )
        percent_improvement = scheduler.get_improvement_over_baseline(
+            experiment=scheduler.experiment,
+            generation_strategy=scheduler.standard_generation_strategy,
            baseline_arm_name=first_trial_name,
        )

@@ -2209,11 +2211,7 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:
        self.branin_experiment.optimization_config = (
            get_branin_multi_objective_optimization_config()
        )
-
-        gs = self._get_generation_strategy_strategy_for_test(
-            experiment=self.branin_experiment,
-            generation_strategy=self.sobol_MBM_GS,
-        )
+        gs = self.sobol_MBM_GS

        scheduler = Scheduler(
            experiment=self.branin_experiment,

@@ -2227,6 +2225,8 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:

        with self.assertRaises(NotImplementedError):
            scheduler.get_improvement_over_baseline(
+                experiment=scheduler.experiment,
+                generation_strategy=scheduler.standard_generation_strategy,
                baseline_arm_name=None,
            )

@@ -2236,10 +2236,7 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
        experiment.name = f"{self.branin_experiment.name}_but_moo"
        experiment.runner = self.runner

-        gs = self._get_generation_strategy_strategy_for_test(
-            experiment=experiment,
-            generation_strategy=self.two_sobol_steps_GS,
-        )
+        gs = self.two_sobol_steps_GS
        scheduler = Scheduler(
            experiment=self.branin_experiment,  # Has runner and metrics.
            generation_strategy=gs,

@@ -2251,8 +2248,10 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
            db_settings=self.db_settings_if_always_needed,
        )

-        with self.assertRaises(UserInputError):
+        with self.assertRaises(ValueError):
            scheduler.get_improvement_over_baseline(
+                experiment=scheduler.experiment,
+                generation_strategy=scheduler.standard_generation_strategy,
                baseline_arm_name=None,
            )

@@ -2267,19 +2266,20 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
        scheduler.experiment = exp_copy

        with self.assertRaises(ValueError):
-            scheduler.get_improvement_over_baseline(baseline_arm_name="baseline")
+            scheduler.get_improvement_over_baseline(
+                experiment=scheduler.experiment,
+                generation_strategy=scheduler.standard_generation_strategy,
+                baseline_arm_name="baseline",
+            )

    def test_get_improvement_over_baseline_no_baseline(self) -> None:
        """Test that get_improvement_over_baseline returns UserInputError when
        baseline is not found in data."""
        n_total_trials = 8
-        gs = self._get_generation_strategy_strategy_for_test(
-            experiment=self.branin_experiment,
-            generation_strategy=self.two_sobol_steps_GS,
-        )

+        experiment = self.branin_experiment
+        gs = self.two_sobol_steps_GS
        scheduler = Scheduler(
-            experiment=self.branin_experiment,  # Has runner and metrics.
+            experiment=experiment,  # Has runner and metrics.
            generation_strategy=gs,
            options=SchedulerOptions(
                total_trials=n_total_trials,
@@ -2293,6 +2293,8 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None:

        with self.assertRaises(UserInputError):
            scheduler.get_improvement_over_baseline(
+                experiment=experiment,
+                generation_strategy=gs,
                baseline_arm_name="baseline_arm_not_in_data",
            )

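The updated tests above pin down the relocated method's new call shape: the experiment and generation strategy are now passed explicitly instead of being read off the Scheduler instance. A hedged sketch of the signature on BestPointMixin, inferred from these call sites rather than copied from the PR:

# Signature sketch inferred from the test call sites above; the PR's actual
# implementation on BestPointMixin may differ in detail.
class BestPointMixin:
    def get_improvement_over_baseline(
        self,
        experiment: "Experiment",
        generation_strategy: "GenerationStrategy",
        baseline_arm_name: str | None = None,
    ) -> float:
        ...

Making the dependencies explicit is what lets the helper live on the shared mixin: any caller with an experiment and a generation strategy can use it, not just a Scheduler.
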
ax/service/tests/test_best_point_utils.py: 76 additions & 1 deletion
@@ -6,15 +6,18 @@

# pyre-strict

+import copy
import random
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, PropertyMock

import pandas as pd
import torch
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
+from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
+from ax.core.metric import Metric
from ax.core.objective import ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
@@ -32,10 +35,12 @@
    get_best_raw_objective_point,
    logger as best_point_logger,
)
+from ax.service.utils.best_point_utils import select_baseline_name_default_first_trial
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
    get_branin_experiment,
    get_branin_metric,
+    get_branin_search_space,
    get_experiment_with_observations,
    get_sobol,
)

@@ -556,6 +561,76 @@ def test_is_row_feasible(self) -> None:
            df.index, feasible_series.index, check_names=False
        )

+    def test_compare_to_baseline_select_baseline_name_default_first_trial(self) -> None:
+        OBJECTIVE_METRIC = "objective"
+        true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True)
+        experiment = Experiment(
+            search_space=get_branin_search_space(),
+            tracking_metrics=[true_obj_metric],
+        )
+
+        with patch.object(
+            Experiment, "arms_by_name", new_callable=PropertyMock
+        ) as mock_arms_by_name:
+            mock_arms_by_name.return_value = {"arm1": "value1", "arm2": "value2"}
+            self.assertEqual(
+                select_baseline_name_default_first_trial(
+                    experiment=experiment,
+                    baseline_arm_name="arm1",
+                ),
+                ("arm1", False),
+            )
+
+        # specified baseline arm not in trial
+        wrong_baseline_name = "wrong_baseline_name"
+        with self.assertRaisesRegex(
+            ValueError,
+            "Arm by name .*" + " not found.",
+        ):
+            select_baseline_name_default_first_trial(
+                experiment=experiment,
+                baseline_arm_name=wrong_baseline_name,
+            )
+
+        # status quo baseline arm
+        experiment_with_status_quo = copy.deepcopy(experiment)
+        experiment_with_status_quo.status_quo = Arm(
+            name="status_quo",
+            parameters={"x1": 0, "x2": 0},
+        )
+        self.assertEqual(
+            select_baseline_name_default_first_trial(
+                experiment=experiment_with_status_quo,
+                baseline_arm_name=None,
+            ),
+            ("status_quo", False),
+        )
+        # first arm from trials
+        custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2})
+        experiment.new_trial().add_arm(custom_arm)
+        self.assertEqual(
+            select_baseline_name_default_first_trial(
+                experiment=experiment,
+                baseline_arm_name=None,
+            ),
+            ("m_0", True),
+        )
+
+        # none selected
+        experiment_with_no_valid_baseline = Experiment(
+            search_space=get_branin_search_space(),
+            tracking_metrics=[true_obj_metric],
+        )
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "Could not find valid baseline arm.",
+        ):
+            select_baseline_name_default_first_trial(
+                experiment=experiment_with_no_valid_baseline,
+                baseline_arm_name=None,
+            )
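
Read together, the assertions in this new test characterize the baseline-selection helper end to end. A hedged reconstruction of that logic, capturing only the behavior the test exercises (the real implementation lives in ax.service.utils.best_point_utils and may differ in detail):

# Hedged reconstruction of the behavior tested above; not the PR's code.
# The bool in the returned tuple appears to mean "defaulted to the first
# trial's arm" rather than using an explicitly chosen baseline.
def select_baseline_name_default_first_trial(
    experiment: "Experiment", baseline_arm_name: str | None
) -> tuple[str, bool]:
    # An explicitly named baseline must exist on the experiment.
    if baseline_arm_name is not None:
        if baseline_arm_name not in experiment.arms_by_name:
            raise ValueError(f"Arm by name {baseline_arm_name} not found.")
        return baseline_arm_name, False
    # Otherwise prefer the status quo arm when one is attached.
    if experiment.status_quo is not None:
        return experiment.status_quo.name, False
    # Otherwise default to the first arm of the first trial.
    if experiment.trials and experiment.trials[0].arms:
        return experiment.trials[0].arms[0].name, True
    raise ValueError("Could not find valid baseline arm.")
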


def _repeat_elements(list_to_replicate: list[bool], n_repeats: int) -> pd.Series:
    return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)])