Skip to content

Commit

Permalink
Accept any GSInterface in scheduler
Browse files Browse the repository at this point in the history
Differential Revision: D51307866
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 28, 2023
1 parent da7674c commit 6c15611
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 55 deletions.
14 changes: 13 additions & 1 deletion ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.core.experiment import Experiment
from ax.core.utils import get_model_times
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.service.scheduler import Scheduler
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.utils.sampling import manual_seed

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -104,7 +106,17 @@ def benchmark_replication(
scheduler.run_n_trials(max_trials=problem.num_trials)

optimization_trace = np.array(scheduler.get_trace())
num_baseline_trials = scheduler.generation_strategy._steps[0].num_trials
num_baseline_trials = (
checked_cast(
GenerationStrategy,
scheduler.generation_strategy,
exception=ValueError(
"This functionality is only supported with a local GenerationStrategy"
),
)
._steps[0]
.num_trials
)
score_trace = compute_score_trace(
optimization_trace=optimization_trace,
num_baseline_trials=num_baseline_trials,
Expand Down
3 changes: 2 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import DataType, Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -1730,7 +1731,7 @@ def _set_generation_strategy(

def _save_generation_strategy_to_db_if_possible(
self,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
suppress_all_errors: bool = False,
) -> bool:
return super()._save_generation_strategy_to_db_if_possible(
Expand Down
53 changes: 31 additions & 22 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import ax.service.utils.early_stopping as early_stopping_utils
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -68,7 +69,7 @@
set_stderr_log_level,
)
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import not_none
from ax.utils.common.typeutils import checked_cast, not_none
from pyre_extensions import assert_is_instance


Expand Down Expand Up @@ -161,7 +162,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
"""

experiment: Experiment
generation_strategy: GenerationStrategy
generation_strategy: GenerationStrategyInterface
options: SchedulerOptions
logger: LoggerAdapter
# Mapping of form {short string identifier -> message to show in reported
Expand Down Expand Up @@ -205,7 +206,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
def __init__(
self,
experiment: Experiment,
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
options: SchedulerOptions,
db_settings: Optional[DBSettings] = None,
_skip_experiment_save: bool = False,
Expand All @@ -220,7 +221,7 @@ def __init__(

if not isinstance(experiment, Experiment):
raise TypeError("{experiment} is not an Ax experiment.")
if not isinstance(generation_strategy, GenerationStrategy):
if not isinstance(generation_strategy, GenerationStrategyInterface):
raise TypeError("{generation_strategy} is not a generation strategy.")
self._validate_options(options=options)

Expand Down Expand Up @@ -381,6 +382,19 @@ def runner(self) -> Runner:
"""
return not_none(self.experiment.runner)

@property
def concrete_generation_strategy(self) -> GenerationStrategy:
    """The scheduler's generation strategy, narrowed to a concrete
    ``GenerationStrategy``.

    Raises an ``UnsupportedError`` (via ``checked_cast``) when the
    underlying strategy is some other ``GenerationStrategyInterface``
    implementation rather than a local ``GenerationStrategy``.
    """
    unsupported = UnsupportedError(
        "This functionality is only supported with a local GenerationStrategy"
    )
    return checked_cast(GenerationStrategy, self.generation_strategy, unsupported)

def __repr__(self) -> str:
"""Short user-friendly string representation."""
if not hasattr(self, "experiment"):
Expand Down Expand Up @@ -446,7 +460,7 @@ def get_best_trial(
) -> Optional[Tuple[int, TParameterization, Optional[TModelPredictArm]]]:
return self._get_best_trial(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.concrete_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
Expand All @@ -461,7 +475,7 @@ def get_pareto_optimal_parameters(
) -> Optional[Dict[int, Tuple[TParameterization, TModelPredictArm]]]:
return self._get_pareto_optimal_parameters(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.concrete_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
Expand All @@ -476,7 +490,7 @@ def get_hypervolume(
) -> float:
return BestPointMixin._get_hypervolume(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.concrete_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
Expand Down Expand Up @@ -1449,12 +1463,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
Returns:
List of trials, empty if generation is not possible.
"""
pending = get_pending_observation_features_based_on_trial_status(
experiment=self.experiment
)
try:
generator_runs = self._gen_new_trials_from_generation_strategy(
num_trials=num_trials, n=n, pending=pending
num_trials=num_trials, n=n
)
except OptimizationComplete as err:
completion_str = f"Optimization complete: {err}"
Expand Down Expand Up @@ -1483,42 +1494,40 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
self.logger.debug(f"Message from generation strategy: {err}")
return []

if (
self.options.trial_type == TrialType.TRIAL
and len(generator_runs[0].arms) > 1
if self.options.trial_type == TrialType.TRIAL and any(
len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1
for generator_run_list in generator_runs
):
raise SchedulerInternalError(
"Generation strategy produced multiple arms when only one was expected."
)

return [
self.experiment.new_batch_trial(
generator_run=generator_run,
generator_runs=generator_run_list,
ttl_seconds=self.options.ttl_seconds_for_trials,
)
if self.options.trial_type == TrialType.BATCH_TRIAL
else self.experiment.new_trial(
generator_run=generator_run,
generator_run=generator_run_list[0],
ttl_seconds=self.options.ttl_seconds_for_trials,
)
for generator_run in generator_runs
for generator_run_list in generator_runs
]

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
n: int,
pending: Optional[Dict[str, List[ObservationFeatures]]],
) -> List[GeneratorRun]:
) -> List[List[GeneratorRun]]:
"""Generates a list of length ``num_trials``, each element of which is a
list of ``GeneratorRun``s, by calling the
``gen_for_multiple_trials_with_multiple_models`` method of the scheduler's
``generation_strategy``; pending observations are handled by the
generation strategy itself.
"""
return self.generation_strategy._gen_multiple(
return self.generation_strategy.gen_for_multiple_trials_with_multiple_models(
experiment=self.experiment,
num_generator_runs=num_trials,
n=n,
pending_observations=pending,
)

def _update_and_save_trials(
Expand Down Expand Up @@ -1853,7 +1862,7 @@ def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
Returns:
A ModelBridge object fitted to the observations of the scheduler's experiment.
"""
gs = scheduler.generation_strategy # GenerationStrategy
gs = scheduler.concrete_generation_strategy
model_bridge = gs.model # Optional[ModelBridge]
if model_bridge is None: # Need to re-fit the model.
data = scheduler.experiment.fetch_data()
Expand Down
16 changes: 9 additions & 7 deletions ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,10 @@ def test_validate_early_stopping_strategy(self) -> None:
),
)

@patch(
f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
return_value=[get_generator_run()],
@patch.object(
GenerationStrategy,
"gen_for_multiple_trials_with_multiple_models",
return_value=[[get_generator_run()]],
)
def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None:
scheduler = Scheduler(
Expand All @@ -366,7 +367,7 @@ def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None:
@patch(
# Record calls to function, but still execute it.
(
f"{Scheduler.__module__}."
f"{GenerationStrategy.__module__}."
"get_pending_observation_features_based_on_trial_status"
),
side_effect=get_pending_observation_features_based_on_trial_status,
Expand Down Expand Up @@ -987,8 +988,9 @@ def test_base_report_results(self) -> None:
)
self.assertEqual(scheduler.run_n_trials(max_trials=3), OptimizationResult())

@patch(
f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
@patch.object(
GenerationStrategy,
"gen_for_multiple_trials_with_multiple_models",
side_effect=OptimizationComplete("test error"),
)
def test_optimization_complete(self, _) -> None:
Expand Down Expand Up @@ -1327,7 +1329,7 @@ def test_get_fitted_model_bridge(self) -> None:
with patch.object(
GenerationStrategy,
"_fit_current_model",
wraps=scheduler.generation_strategy._fit_current_model,
wraps=generation_strategy._fit_current_model,
) as fit_model:
get_fitted_model_bridge(scheduler)
fit_model.assert_called_once()
Expand Down
45 changes: 26 additions & 19 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import (
IncompatibleDependencyVersion,
Expand Down Expand Up @@ -155,7 +156,7 @@ def _get_experiment_and_generation_strategy_db_id(
return exp_id, gs_id

def _maybe_save_experiment_and_generation_strategy(
self, experiment: Experiment, generation_strategy: GenerationStrategy
self, experiment: Experiment, generation_strategy: GenerationStrategyInterface
) -> Tuple[bool, bool]:
"""If DB settings are set on this `WithDBSettingsBase` instance, checks
whether given experiment and generation strategy are already saved and
Expand Down Expand Up @@ -304,7 +305,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible(
self,
experiment: Experiment,
trials: List[BaseTrial],
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
new_generator_runs: List[GeneratorRun],
reduce_state_generator_runs: bool = False,
) -> None:
Expand Down Expand Up @@ -386,7 +387,7 @@ def _save_or_update_trials_in_db_if_possible(

def _save_generation_strategy_to_db_if_possible(
self,
generation_strategy: Optional[GenerationStrategy] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
suppress_all_errors: bool = False,
) -> bool:
"""Saves given generation strategy if DB settings are set on this
Expand All @@ -399,18 +400,21 @@ def _save_generation_strategy_to_db_if_possible(
bool: Whether the generation strategy was saved.
"""
if self.db_settings_set and generation_strategy is not None:
_save_generation_strategy_to_db_if_possible(
generation_strategy=generation_strategy,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
)
# Only a local ``GenerationStrategy`` needs to be saved to the
# database, since only local strategies are modified in this process.
if isinstance(generation_strategy, GenerationStrategy):
_save_generation_strategy_to_db_if_possible(
generation_strategy=generation_strategy,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
)
return True
return False

def _update_generation_strategy_in_db_if_possible(
self,
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
new_generator_runs: List[GeneratorRun],
reduce_state_generator_runs: bool = False,
) -> bool:
Expand All @@ -427,15 +431,18 @@ def _update_generation_strategy_in_db_if_possible(
bool: Whether the experiment was saved.
"""
if self.db_settings_set:
_update_generation_strategy_in_db_if_possible(
generation_strategy=generation_strategy,
new_generator_runs=new_generator_runs,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return True
# Only a local ``GenerationStrategy`` needs to be updated in the
# database, since only local strategies are modified in this process.
if isinstance(generation_strategy, GenerationStrategy):
_update_generation_strategy_in_db_if_possible(
generation_strategy=generation_strategy,
new_generator_runs=new_generator_runs,
encoder=self.db_settings.encoder,
decoder=self.db_settings.decoder,
suppress_all_errors=self._suppress_all_errors,
reduce_state_generator_runs=reduce_state_generator_runs,
)
return True
return False

def _update_experiment_properties_in_db(
Expand Down
1 change: 0 additions & 1 deletion ax/telemetry/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from math import inf

from ax.modelbridge.generation_strategy import GenerationStrategy

from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS


Expand Down
4 changes: 2 additions & 2 deletions ax/telemetry/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
),
generation_strategy_created_record=(
GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.concrete_generation_strategy
)
),
scheduler_total_trials=scheduler.options.total_trials,
Expand All @@ -68,7 +68,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
),
transformed_dimensionality=_get_max_transformed_dimensionality(
search_space=scheduler.experiment.search_space,
generation_strategy=scheduler.generation_strategy,
generation_strategy=scheduler.concrete_generation_strategy,
),
)

Expand Down
4 changes: 2 additions & 2 deletions ax/telemetry/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
),
generation_strategy_created_record=(
GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.concrete_generation_strategy
)
),
scheduler_total_trials=0,
Expand All @@ -63,7 +63,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
experiment=scheduler.experiment
).__dict__,
**GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.concrete_generation_strategy
).__dict__,
"scheduler_total_trials": 0,
"scheduler_max_pending_trials": 10,
Expand Down

0 comments on commit 6c15611

Please sign in to comment.