diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index c48709092cd..0996a744b49 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -104,7 +104,7 @@ def benchmark_replication( scheduler.run_n_trials(max_trials=problem.num_trials) optimization_trace = np.array(scheduler.get_trace()) - num_baseline_trials = scheduler.generation_strategy._steps[0].num_trials + num_baseline_trials = scheduler.standard_generation_strategy._steps[0].num_trials score_trace = compute_score_trace( optimization_trace=optimization_trace, num_baseline_trials=num_baseline_trials, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 6df0aa5dd25..9f4e9a3b192 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -29,6 +29,7 @@ import ax.service.utils.early_stopping as early_stopping_utils from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData from ax.core.map_metric import MapMetric @@ -161,7 +162,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin): """ experiment: Experiment - generation_strategy: GenerationStrategy + generation_strategy: GenerationStrategyInterface options: SchedulerOptions logger: LoggerAdapter # Mapping of form {short string identifier -> message to show in reported @@ -205,7 +206,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin): def __init__( self, experiment: Experiment, - generation_strategy: GenerationStrategy, + generation_strategy: GenerationStrategyInterface, options: SchedulerOptions, db_settings: Optional[DBSettings] = None, _skip_experiment_save: bool = False, @@ -220,7 +221,7 @@ def __init__( if not isinstance(experiment, Experiment): raise TypeError("{experiment} is not an Ax experiment.") - if not isinstance(generation_strategy, GenerationStrategy): + if not isinstance(generation_strategy, GenerationStrategyInterface): raise TypeError("{generation_strategy} is not a generation strategy.") self._validate_options(options=options) @@ -381,6 +382,21 @@ def runner(self) -> Runner: """ return not_none(self.experiment.runner) + @property + def standard_generation_strategy(self) -> GenerationStrategy: + """Used for operations in the scheduler that can only be done with + and instance of ``GenerationStrategy``. + """ + gs = self.generation_strategy + if not isinstance(gs, GenerationStrategy): + raise NotImplementedError( + "This functionality is only supported with instances of " + "`GenerationStrategy` (one that uses `GenerationStrategy` " + "class) and not yet with other types of " + "`GenerationStrategyInterface`." + ) + return gs + def __repr__(self) -> str: """Short user-friendly string representation.""" if not hasattr(self, "experiment"): @@ -446,7 +462,7 @@ def get_best_trial( ) -> Optional[Tuple[int, TParameterization, Optional[TModelPredictArm]]]: return self._get_best_trial( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, optimization_config=optimization_config, trial_indices=trial_indices, use_model_predictions=use_model_predictions, @@ -461,7 +477,7 @@ def get_pareto_optimal_parameters( ) -> Optional[Dict[int, Tuple[TParameterization, TModelPredictArm]]]: return self._get_pareto_optimal_parameters( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, optimization_config=optimization_config, trial_indices=trial_indices, use_model_predictions=use_model_predictions, @@ -476,7 +492,7 @@ def get_hypervolume( ) -> float: return BestPointMixin._get_hypervolume( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, optimization_config=optimization_config, trial_indices=trial_indices, use_model_predictions=use_model_predictions, @@ -1449,12 +1465,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]: Returns: List of trials, empty if generation is not possible. """ - pending = get_pending_observation_features_based_on_trial_status( - experiment=self.experiment - ) try: generator_runs = self._gen_new_trials_from_generation_strategy( - num_trials=num_trials, n=n, pending=pending + num_trials=num_trials, n=n ) except OptimizationComplete as err: completion_str = f"Optimization complete: {err}" @@ -1483,9 +1496,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]: self.logger.debug(f"Message from generation strategy: {err}") return [] - if ( - self.options.trial_type == TrialType.TRIAL - and len(generator_runs[0].arms) > 1 + if self.options.trial_type == TrialType.TRIAL and any( + len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1 + for generator_run_list in generator_runs ): raise SchedulerInternalError( "Generation strategy produced multiple arms when only one was expected." @@ -1493,32 +1506,30 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]: return [ self.experiment.new_batch_trial( - generator_run=generator_run, + generator_runs=generator_run_list, ttl_seconds=self.options.ttl_seconds_for_trials, ) if self.options.trial_type == TrialType.BATCH_TRIAL else self.experiment.new_trial( - generator_run=generator_run, + generator_run=generator_run_list[0], ttl_seconds=self.options.ttl_seconds_for_trials, ) - for generator_run in generator_runs + for generator_run_list in generator_runs ] def _gen_new_trials_from_generation_strategy( self, num_trials: int, n: int, - pending: Optional[Dict[str, List[ObservationFeatures]]], - ) -> List[GeneratorRun]: + ) -> List[List[GeneratorRun]]: """Generates a list ``GeneratorRun``s of length of ``num_trials`` using the ``_gen_multiple`` method of the scheduler's ``generation_strategy``, taking into account any ``pending`` observations. """ - return self.generation_strategy._gen_multiple( + return self.generation_strategy.gen_for_multiple_trials_with_multiple_models( experiment=self.experiment, num_generator_runs=num_trials, n=n, - pending_observations=pending, ) def _update_and_save_trials( @@ -1853,7 +1864,7 @@ def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge: Returns: A ModelBridge object fitted to the observations of the scheduler's experiment. """ - gs = scheduler.generation_strategy # GenerationStrategy + gs = scheduler.standard_generation_strategy model_bridge = gs.model # Optional[ModelBridge] if model_bridge is None: # Need to re-fit the model. data = scheduler.experiment.fetch_data() diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 7a589b00c20..9626498caa2 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -59,6 +59,7 @@ get_branin_search_space, get_generator_run, get_sobol, + SpecialGenerationStrategy, ) from pyre_extensions import none_throws @@ -350,9 +351,10 @@ def test_validate_early_stopping_strategy(self) -> None: ), ) - @patch( - f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple", - return_value=[get_generator_run()], + @patch.object( + GenerationStrategy, + "gen_for_multiple_trials_with_multiple_models", + return_value=[[get_generator_run()]], ) def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None: scheduler = Scheduler( @@ -366,7 +368,7 @@ def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None: @patch( # Record calls to function, but still execute it. ( - f"{Scheduler.__module__}." + f"{GenerationStrategy.__module__}." "get_pending_observation_features_based_on_trial_status" ), side_effect=get_pending_observation_features_based_on_trial_status, @@ -987,8 +989,9 @@ def test_base_report_results(self) -> None: ) self.assertEqual(scheduler.run_n_trials(max_trials=3), OptimizationResult()) - @patch( - f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple", + @patch.object( + GenerationStrategy, + "gen_for_multiple_trials_with_multiple_models", side_effect=OptimizationComplete("test error"), ) def test_optimization_complete(self, _) -> None: @@ -1327,7 +1330,7 @@ def test_get_fitted_model_bridge(self) -> None: with patch.object( GenerationStrategy, "_fit_current_model", - wraps=scheduler.generation_strategy._fit_current_model, + wraps=generation_strategy._fit_current_model, ) as fit_model: get_fitted_model_bridge(scheduler) fit_model.assert_called_once() @@ -1359,3 +1362,25 @@ def test_get_fitted_model_bridge(self) -> None: ) self.assertIsInstance(empty_metrics, dict) self.assertTrue(len(empty_metrics) == 0) + + def test_standard_generation_strategy(self) -> None: + with self.subTest("with a `GenerationStrategy"): + # Tests standard GS creation. + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GPEI_GS, + options=SchedulerOptions(), + ) + self.assertEqual(scheduler.standard_generation_strategy, self.sobol_GPEI_GS) + + with self.subTest("with a `SpecialGenerationStrategy`"): + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=SpecialGenerationStrategy(), + options=SchedulerOptions(), + ) + with self.assertRaisesRegex( + NotImplementedError, + "only supported with instances of `GenerationStrategy`", + ): + scheduler.standard_generation_strategy diff --git a/ax/telemetry/generation_strategy.py b/ax/telemetry/generation_strategy.py index b510e63d5f9..4b9c95f9cee 100644 --- a/ax/telemetry/generation_strategy.py +++ b/ax/telemetry/generation_strategy.py @@ -9,7 +9,6 @@ from math import inf from ax.modelbridge.generation_strategy import GenerationStrategy - from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py index f288292fa4f..4e16ee905e0 100644 --- a/ax/telemetry/scheduler.py +++ b/ax/telemetry/scheduler.py @@ -49,7 +49,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord: ), generation_strategy_created_record=( GenerationStrategyCreatedRecord.from_generation_strategy( - generation_strategy=scheduler.generation_strategy + generation_strategy=scheduler.standard_generation_strategy ) ), scheduler_total_trials=scheduler.options.total_trials, @@ -68,7 +68,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord: ), transformed_dimensionality=_get_max_transformed_dimensionality( search_space=scheduler.experiment.search_space, - generation_strategy=scheduler.generation_strategy, + generation_strategy=scheduler.standard_generation_strategy, ), ) diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py index c674e588e90..e07e1d0a612 100644 --- a/ax/telemetry/tests/test_scheduler.py +++ b/ax/telemetry/tests/test_scheduler.py @@ -45,7 +45,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None: ), generation_strategy_created_record=( GenerationStrategyCreatedRecord.from_generation_strategy( - generation_strategy=scheduler.generation_strategy + generation_strategy=scheduler.standard_generation_strategy ) ), scheduler_total_trials=0, @@ -63,7 +63,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None: experiment=scheduler.experiment ).__dict__, **GenerationStrategyCreatedRecord.from_generation_strategy( - generation_strategy=scheduler.generation_strategy + generation_strategy=scheduler.standard_generation_strategy ).__dict__, "scheduler_total_trials": 0, "scheduler_max_pending_trials": 10,