Skip to content

Commit

Permalink
Reduce complexity in wait_for_trials_and_report_results (#3175)
Browse files Browse the repository at this point in the history
Summary:

This diff contains the following changes:
1. Deprecates ESSs `seconds_between_polls`, as we can leverage SchedulerOption's `init_seconds_between_polls` instead
2. Set SchedulerOption's `seconds_between_polls_backoff_factor` to 1 if an ESS is provided
3. Use a try-catch in `wait_for_completed_trials_and_report_results` when doing `idle_callback`, so it doesn't fail the Scheduler

Reviewed By: mgarrard

Differential Revision: D67178422
  • Loading branch information
paschai authored and facebook-github-bot committed Dec 18, 2024
1 parent c5e038a commit 5e2d4dd
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 79 deletions.
13 changes: 0 additions & 13 deletions ax/benchmark/benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

import warnings
from dataclasses import dataclass

from ax.core.experiment import Experiment
Expand Down Expand Up @@ -64,18 +63,6 @@ class BenchmarkMethod(Base):
def __post_init__(self) -> None:
if self.name == "DEFAULT":
self.name = self.generation_strategy.name
early_stopping_strategy = self.early_stopping_strategy
if early_stopping_strategy is not None:
seconds_between_polls = early_stopping_strategy.seconds_between_polls
if seconds_between_polls > 0:
warnings.warn(
"`early_stopping_strategy.seconds_between_polls` is "
f"{seconds_between_polls}, but benchmarking uses 0 seconds "
"between polls. Setting "
"`early_stopping_strategy.seconds_between_polls` to 0.",
stacklevel=1,
)
early_stopping_strategy.seconds_between_polls = 0

def get_best_parameters(
self,
Expand Down
9 changes: 0 additions & 9 deletions ax/benchmark/tests/test_benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from ax.benchmark.methods.sobol import get_sobol_generation_strategy
from ax.core.experiment import Experiment
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
from ax.utils.common.testutils import TestCase
from pyre_extensions import none_throws

Expand Down Expand Up @@ -45,14 +44,6 @@ def test_benchmark_method(self) -> None:
)
)

def test_raises_when_ess_polls_with_delay(self) -> None:
ess = ThresholdEarlyStoppingStrategy(seconds_between_polls=10)
with self.assertWarnsRegex(Warning, "seconds_between_polls"):
BenchmarkMethod(
generation_strategy=self.gs,
early_stopping_strategy=ess,
)

def test_get_best_parameters(self) -> None:
"""
This is tested more thoroughly in `test_benchmark` -- setting up an
Expand Down
10 changes: 0 additions & 10 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class BaseEarlyStoppingStrategy(ABC, Base):
def __init__(
self,
metric_names: Iterable[str] | None = None,
seconds_between_polls: int = 300,
min_progression: float | None = None,
max_progression: float | None = None,
min_curves: int | None = None,
Expand All @@ -82,8 +81,6 @@ def __init__(
Args:
metric_names: The names of the metrics the strategy will interact with.
If no metric names are provided the objective metric is assumed.
seconds_between_polls: How often to poll the early stopping metric to
evaluate whether or not the trial should be early stopped.
min_progression: Only stop trials if the latest progression value
(e.g. timestamp, epochs, training data used) is greater than this
threshold. Prevents stopping prematurely before enough data is gathered
Expand All @@ -103,10 +100,7 @@ def __init__(
should be > 0 to ensure that at least one trial has completed and that
we have a reliable approximation for `prog_max`.
"""
if seconds_between_polls < 0:
raise ValueError("`seconds_between_polls may not be less than 0.")
self.metric_names = metric_names
self.seconds_between_polls = seconds_between_polls
self.min_progression = min_progression
self.max_progression = max_progression
self.min_curves = min_curves
Expand Down Expand Up @@ -446,7 +440,6 @@ class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
def __init__(
self,
metric_names: Iterable[str] | None = None,
seconds_between_polls: int = 300,
min_progression: float | None = None,
max_progression: float | None = None,
min_curves: int | None = None,
Expand All @@ -459,8 +452,6 @@ def __init__(
Args:
metric_names: The names of the metrics the strategy will interact with.
If no metric names are provided the objective metric is assumed.
seconds_between_polls: How often to poll the early stopping metric to
evaluate whether or not the trial should be early stopped.
min_progression: Only stop trials if the latest progression value
(e.g. timestamp, epochs, training data used) is greater than this
threshold. Prevents stopping prematurely before enough data is gathered
Expand All @@ -485,7 +476,6 @@ def __init__(
"""
super().__init__(
metric_names=metric_names,
seconds_between_polls=seconds_between_polls,
min_progression=min_progression,
max_progression=max_progression,
min_curves=min_curves,
Expand Down
5 changes: 1 addition & 4 deletions ax/early_stopping/strategies/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ def __init__(
self,
left: BaseEarlyStoppingStrategy,
right: BaseEarlyStoppingStrategy,
seconds_between_polls: int = 300,
) -> None:
super().__init__(
seconds_between_polls=seconds_between_polls,
)
super().__init__()

self.left = left
self.right = right
Expand Down
4 changes: 0 additions & 4 deletions ax/early_stopping/strategies/percentile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class PercentileEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
def __init__(
self,
metric_names: Iterable[str] | None = None,
seconds_between_polls: int = 300,
percentile_threshold: float = 50.0,
min_progression: float | None = 10,
max_progression: float | None = None,
Expand All @@ -42,8 +41,6 @@ def __init__(
metric_names: A (length-one) list of name of the metric to observe. If
None will default to the objective metric on the Experiment's
OptimizationConfig.
seconds_between_polls: How often to poll the early stopping metric to
evaluate whether or not the trial should be early stopped.
percentile_threshold: Falling below this threshold compared to other trials
at the same step will stop the run. Must be between 0.0 and 100.0.
e.g. if percentile_threshold=25.0, the bottom 25% of trials are stopped.
Expand Down Expand Up @@ -71,7 +68,6 @@ def __init__(
"""
super().__init__(
metric_names=metric_names,
seconds_between_polls=seconds_between_polls,
trial_indices_to_ignore=trial_indices_to_ignore,
min_progression=min_progression,
max_progression=max_progression,
Expand Down
4 changes: 0 additions & 4 deletions ax/early_stopping/strategies/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class ThresholdEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
def __init__(
self,
metric_names: Iterable[str] | None = None,
seconds_between_polls: int = 300,
metric_threshold: float = 0.2,
min_progression: float | None = 10,
max_progression: float | None = None,
Expand All @@ -39,8 +38,6 @@ def __init__(
metric_names: A (length-one) list of name of the metric to observe. If
None will default to the objective metric on the Experiment's
OptimizationConfig.
seconds_between_polls: How often to poll the early stopping metric to
evaluate whether or not the trial should be early stopped.
metric_threshold: The metric threshold that a trial needs to reach by
min_progression in order not to be stopped.
min_progression: Only stop trials if the latest progression value
Expand All @@ -64,7 +61,6 @@ def __init__(
"""
super().__init__(
metric_names=metric_names,
seconds_between_polls=seconds_between_polls,
min_progression=min_progression,
max_progression=max_progression,
min_curves=min_curves,
Expand Down
42 changes: 12 additions & 30 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,7 @@ def run_n_trials(
max_trials: int,
ignore_global_stopping_strategy: bool = False,
timeout_hours: float | None = None,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
idle_callback: Callable[[Scheduler], Any] | None = None,
idle_callback: Optional[Callable[[Scheduler], None]] = None,
) -> OptimizationResult:
"""Run up to ``max_trials`` trials; will run all ``max_trials`` unless
completion criterion is reached. For base ``Scheduler``, completion criterion
Expand Down Expand Up @@ -625,8 +624,7 @@ def run_n_trials(
def run_all_trials(
self,
timeout_hours: float | None = None,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
idle_callback: Callable[[Scheduler], Any] | None = None,
idle_callback: Optional[Callable[[Scheduler], None]] = None,
) -> OptimizationResult:
"""Run all trials until ``should_consider_optimization_complete`` yields
true (by default, ``should_consider_optimization_complete`` will yield true when
Expand Down Expand Up @@ -874,36 +872,15 @@ def wait_for_completed_trials_and_report_results(
dict. The contents of the dict depend on the implementation of
`report_results` in the given `Scheduler` subclass.
"""
if (
self.options.init_seconds_between_polls is None
and self.options.early_stopping_strategy is None
):
if self.options.init_seconds_between_polls is None:
raise ValueError(
"Default `wait_for_completed_trials_and_report_results` in base "
"`Scheduler` relies on non-null `init_seconds_between_polls` scheduler "
"option or for an EarlyStoppingStrategy to be specified."
)
elif (
self.options.init_seconds_between_polls is not None
and self.options.early_stopping_strategy is not None
):
self.logger.warning(
"Both `init_seconds_between_polls` and `early_stopping_strategy "
"supplied. `init_seconds_between_polls="
f"{self.options.init_seconds_between_polls}` will be overrridden by "
"`early_stopping_strategy.seconds_between_polls="
f"{self.options.early_stopping_strategy.seconds_between_polls}` and "
"polling will take place at a constant rate."
"option."
)

seconds_between_polls = self.options.init_seconds_between_polls
backoff_factor = self.options.seconds_between_polls_backoff_factor
if self.options.early_stopping_strategy is not None:
seconds_between_polls = (
self.options.early_stopping_strategy.seconds_between_polls
)
# Do not backoff with early stopping, a constant heartbeat is preferred
backoff_factor = 1

total_seconds_elapsed = 0
while len(self.pending_trials) > 0 and not self.poll_and_process_results():
Expand All @@ -912,7 +889,13 @@ def wait_for_completed_trials_and_report_results(
# criterion again and and re-attempt scheduling more trials.

if idle_callback is not None:
idle_callback(self)
try:
idle_callback(self)
except Exception as e:
self.logger.warning(
f"Exception raised in ``idle_callback``: {e}. "
"Continuing to poll for completed trials."
)

log_seconds = (
int(seconds_between_polls)
Expand Down Expand Up @@ -1558,8 +1541,7 @@ def _abort_optimization(self, num_preexisting_trials: int) -> dict[str, Any]:
def _complete_optimization(
self,
num_preexisting_trials: int,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
idle_callback: Callable[[Scheduler], Any] | None = None,
idle_callback: Optional[Callable[[Scheduler], None]] = None,
) -> dict[str, Any]:
"""Conclude optimization with waiting for anymore running trials and
return final results via `wait_for_completed_trials_and_report_results`.
Expand Down
17 changes: 14 additions & 3 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,9 +1481,7 @@ def should_stop_trials_early(
generation_strategy=gs,
options=SchedulerOptions(
init_seconds_between_polls=0,
early_stopping_strategy=OddIndexEarlyStoppingStrategy(
seconds_between_polls=1
),
early_stopping_strategy=OddIndexEarlyStoppingStrategy(),
fetch_kwargs={
"overwrite_existing_data": False,
},
Expand Down Expand Up @@ -2935,3 +2933,16 @@ def test_markdown_messages(self) -> None:
self.assertEqual(
scheduler.markdown_messages["Generation strategy"].priority, 10
)

def test_seconds_between_polls_backoff_factor_is_set(self) -> None:
options = SchedulerOptions(
**self.scheduler_options_kwargs,
)

self.assertEqual(options.seconds_between_polls_backoff_factor, 1.5)

options_with_ess = SchedulerOptions(
early_stopping_strategy=DummyEarlyStoppingStrategy(),
**self.scheduler_options_kwargs,
)
self.assertEqual(options_with_ess.seconds_between_polls_backoff_factor, 1.0)
4 changes: 4 additions & 0 deletions ax/service/utils/scheduler_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,7 @@ class SchedulerOptions:
enforce_immutable_search_space_and_opt_config: bool = True
mt_experiment_trial_type: str | None = None
force_candidate_generation: bool = False

def __post_init__(self) -> None:
if self.early_stopping_strategy is not None:
object.__setattr__(self, "seconds_between_polls_backoff_factor", 1)
1 change: 0 additions & 1 deletion ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,6 @@ def percentile_early_stopping_strategy_to_dict(
"min_progression": strategy.min_progression,
"min_curves": strategy.min_curves,
"trial_indices_to_ignore": strategy.trial_indices_to_ignore,
"seconds_between_polls": strategy.seconds_between_polls,
"normalize_progressions": strategy.normalize_progressions,
}

Expand Down
1 change: 0 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2152,7 +2152,6 @@ def get_or_early_stopping_strategy() -> OrEarlyStoppingStrategy:
class DummyEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
def __init__(self, early_stop_trials: dict[int, str | None] | None = None) -> None:
self.early_stop_trials: dict[int, str | None] = early_stop_trials or {}
self.seconds_between_polls = 1

def should_stop_trials_early(
self,
Expand Down

0 comments on commit 5e2d4dd

Please sign in to comment.