Move get_improvement_over_baseline to the BestPointMixin (facebook#3156)
Summary:

As titled. In this change we:

1. Move `get_improvement_over_baseline` to the `BestPointMixin`; it now requires an experiment and a generation strategy as parameters (see the call sketch after this list).
2. Make the `baseline_arm_name` parameter optional, falling back to `select_baseline_arm()` when it is not provided.
3. Introduce a `best_point_utils.py` file to house `select_baseline_arm()`.
4. Modify `select_baseline_arm()` to take a dictionary mapping arm names to `Arm`s instead of an experiment DataFrame.
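A minimal sketch of the new call pattern, inferred from the updated tests in this diff (`scheduler` is assumed to be a `Scheduler`, which mixes in `BestPointMixin`; the keyword names match the test calls below):

```python
# Sketch only, based on the updated tests in this commit: the method now
# takes the experiment and generation strategy explicitly rather than
# reading them off `self`.
percent_improvement = scheduler.get_improvement_over_baseline(
    experiment=scheduler.experiment,
    generation_strategy=scheduler.standard_generation_strategy,
    baseline_arm_name=None,  # now optional; a default baseline is selected
)
```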

Differential Revision: D66472613
paschai authored and facebook-github-bot committed Dec 12, 2024
1 parent f4aa969 commit d6a52a8
Showing 8 changed files with 299 additions and 228 deletions.
65 changes: 0 additions & 65 deletions ax/service/scheduler.py
@@ -1021,71 +1021,6 @@ def summarize_final_result(self) -> OptimizationResult:
"""
return OptimizationResult()

def get_improvement_over_baseline(
self,
baseline_arm_name: str | None = None,
) -> float:
"""Returns the scalarized improvement over baseline, if applicable.
Returns:
For Single Objective cases, returns % improvement of objective.
Positive indicates improvement over baseline. Negative indicates regression.
For Multi Objective cases, throws NotImplementedError
"""
if self.experiment.is_moo_problem:
raise NotImplementedError(
"`get_improvement_over_baseline` not yet implemented"
+ " for multi-objective problems."
)
if not baseline_arm_name:
raise UserInputError(
"`get_improvement_over_baseline` missing required parameter: "
+ f"{baseline_arm_name=}, "
)

optimization_config = self.experiment.optimization_config
if not optimization_config:
raise ValueError("No optimization config found.")

objective_metric_name = optimization_config.objective.metric.name

# get the baseline trial
data = self.experiment.lookup_data().df
data = data[data["arm_name"] == baseline_arm_name]
if len(data) == 0:
raise UserInputError(
"`get_improvement_over_baseline`"
" could not find baseline arm"
f" `{baseline_arm_name}` in the experiment data."
)
data = data[data["metric_name"] == objective_metric_name]
baseline_value = data.iloc[0]["mean"]

# Find objective value of the best trial
idx, param, best_arm = none_throws(
self.get_best_trial(
optimization_config=optimization_config, use_model_predictions=False
)
)
best_arm = none_throws(best_arm)
best_obj_value = best_arm[0][objective_metric_name]

def percent_change(x: float, y: float, minimize: bool) -> float:
if x == 0:
raise ZeroDivisionError(
"Cannot compute percent improvement when denom is zero"
)
percent_change = (y - x) / abs(x) * 100
if minimize:
percent_change = -percent_change
return percent_change

return percent_change(
x=baseline_value,
y=best_obj_value,
minimize=optimization_config.objective.minimize,
)

def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
"""Checks if the failure rate (set in scheduler options) has been exceeded at
any point during the optimization.
36 changes: 19 additions & 17 deletions ax/service/tests/scheduler_test_utils.py
@@ -2197,6 +2197,8 @@ def test_get_improvement_over_baseline(self) -> None:
scheduler.experiment.trials[0].lookup_data().df["arm_name"].iloc[0]
)
percent_improvement = scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=first_trial_name,
)

@@ -2209,11 +2211,7 @@ def test_get_improvement_over_baseline_robustness_not_implemented(self) -> None:
self.branin_experiment.optimization_config = (
get_branin_multi_objective_optimization_config()
)

gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.sobol_MBM_GS,
)
gs = self.sobol_MBM_GS

scheduler = Scheduler(
experiment=self.branin_experiment,
@@ -2227,6 +2225,8 @@

with self.assertRaises(NotImplementedError):
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=None,
)

@@ -2236,10 +2236,7 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
experiment.name = f"{self.branin_experiment.name}_but_moo"
experiment.runner = self.runner

gs = self._get_generation_strategy_strategy_for_test(
experiment=experiment,
generation_strategy=self.two_sobol_steps_GS,
)
gs = self.two_sobol_steps_GS
scheduler = Scheduler(
experiment=self.branin_experiment, # Has runner and metrics.
generation_strategy=gs,
@@ -2251,8 +2248,10 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
db_settings=self.db_settings_if_always_needed,
)

with self.assertRaises(UserInputError):
with self.assertRaises(ValueError):
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name=None,
)

@@ -2267,19 +2266,20 @@ def test_get_improvement_over_baseline_robustness_user_input_error(self) -> None
scheduler.experiment = exp_copy

with self.assertRaises(ValueError):
scheduler.get_improvement_over_baseline(baseline_arm_name="baseline")
scheduler.get_improvement_over_baseline(
experiment=scheduler.experiment,
generation_strategy=scheduler.standard_generation_strategy,
baseline_arm_name="baseline",
)

def test_get_improvement_over_baseline_no_baseline(self) -> None:
"""Test that get_improvement_over_baseline returns UserInputError when
baseline is not found in data."""
n_total_trials = 8
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)

experiment = self.branin_experiment
gs = self.two_sobol_steps_GS
scheduler = Scheduler(
experiment=self.branin_experiment, # Has runner and metrics.
experiment=experiment, # Has runner and metrics.
generation_strategy=gs,
options=SchedulerOptions(
total_trials=n_total_trials,
@@ -2293,6 +2293,8 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None:

with self.assertRaises(UserInputError):
scheduler.get_improvement_over_baseline(
experiment=experiment,
generation_strategy=gs,
baseline_arm_name="baseline_arm_not_in_data",
)

78 changes: 77 additions & 1 deletion ax/service/tests/test_best_point_utils.py
@@ -6,15 +6,18 @@

# pyre-strict

import copy
import random
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, PropertyMock

import pandas as pd
import torch
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import ScalarizedObjective
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
@@ -32,10 +35,12 @@
get_best_raw_objective_point,
logger as best_point_logger,
)
from ax.service.utils.best_point_utils import select_baseline_name_default_first_trial
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_branin_metric,
get_branin_search_space,
get_experiment_with_observations,
get_sobol,
)
@@ -556,6 +561,77 @@ def test_is_row_feasible(self) -> None:
df.index, feasible_series.index, check_names=False
)

def test_compare_to_baseline_select_baseline_name_default_first_trial(self) -> None:
OBJECTIVE_METRIC = "objective"
true_obj_metric = Metric(name=OBJECTIVE_METRIC, lower_is_better=True)
experiment = Experiment(
search_space=get_branin_search_space(),
tracking_metrics=[true_obj_metric],
)

with patch.object(
Experiment, "arms_by_name", new_callable=PropertyMock
) as mock_arms_by_name:
mock_arms_by_name.return_value = {"arm1": "value1", "arm2": "value2"}
self.assertEqual(
select_baseline_name_default_first_trial(
experiment=experiment,
baseline_arm_name="arm1",
),
("arm1", False),
)

# specified baseline arm not in trial
wrong_baseline_name = "wrong_baseline_name"
with self.assertRaisesRegex(
ValueError,
"select_baseline_name_default_first_trial: Arm by name .*" + " not found.",
):
select_baseline_name_default_first_trial(
experiment=experiment,
baseline_arm_name=wrong_baseline_name,
)

# status quo baseline arm
experiment_with_status_quo = copy.deepcopy(experiment)
experiment_with_status_quo.status_quo = Arm(
name="status_quo",
parameters={"x1": 0, "x2": 0},
)
self.assertEqual(
select_baseline_name_default_first_trial(
experiment=experiment_with_status_quo,
baseline_arm_name=None,
),
("status_quo", False),
)
# first arm from trials
custom_arm = Arm(name="m_0", parameters={"x1": 0.1, "x2": 0.2})
experiment.new_trial().add_arm(custom_arm)
self.assertEqual(
select_baseline_name_default_first_trial(
experiment=experiment,
baseline_arm_name=None,
),
("m_0", True),
)

# none selected
experiment_with_no_valid_baseline = Experiment(
search_space=get_branin_search_space(),
tracking_metrics=[true_obj_metric],
)

with self.assertRaisesRegex(
ValueError,
"select_baseline_name_default_first_trial:"
" could not find valid baseline arm",
):
select_baseline_name_default_first_trial(
experiment=experiment_with_no_valid_baseline,
baseline_arm_name=None,
)


def _repeat_elements(list_to_replicate: list[bool], n_repeats: int) -> pd.Series:
return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)])
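Based on the expectations in `test_compare_to_baseline_select_baseline_name_default_first_trial` above, the new baseline-selection helper appears to behave roughly as sketched below. This is a reconstruction from the test expectations, not the actual implementation; in particular, the returned boolean is assumed to flag whether the first-trial default was used.

```python
from ax.core.experiment import Experiment


def select_baseline_name_default_first_trial(
    experiment: Experiment, baseline_arm_name: str | None
) -> tuple[str, bool]:
    # Reconstruction from the tests: returns (arm_name, used_first_trial_default).
    if baseline_arm_name is not None:
        # An explicitly requested baseline must exist on the experiment.
        if baseline_arm_name not in experiment.arms_by_name:
            raise ValueError(
                "select_baseline_name_default_first_trial: Arm by name "
                f"{baseline_arm_name} not found."
            )
        return baseline_arm_name, False
    # Prefer the experiment's status quo arm when one is set.
    if experiment.status_quo is not None:
        return experiment.status_quo.name, False
    # Otherwise default to the first arm of the first trial.
    if experiment.trials:
        first_trial = experiment.trials[min(experiment.trials)]
        if first_trial.arms:
            return first_trial.arms[0].name, True
    raise ValueError(
        "select_baseline_name_default_first_trial:"
        " could not find valid baseline arm"
    )
```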
