diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py
index c48709092cd..a943b6f632b 100644
--- a/ax/benchmark/benchmark.py
+++ b/ax/benchmark/benchmark.py
@@ -32,8 +32,10 @@
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.core.experiment import Experiment
 from ax.core.utils import get_model_times
+from ax.modelbridge.generation_strategy import GenerationStrategy
 from ax.service.scheduler import Scheduler
 from ax.utils.common.logger import get_logger
+from ax.utils.common.typeutils import checked_cast
 from botorch.utils.sampling import manual_seed
 
 logger: Logger = get_logger(__name__)
@@ -104,7 +106,17 @@ def benchmark_replication(
     scheduler.run_n_trials(max_trials=problem.num_trials)
 
     optimization_trace = np.array(scheduler.get_trace())
-    num_baseline_trials = scheduler.generation_strategy._steps[0].num_trials
+    num_baseline_trials = (
+        checked_cast(
+            GenerationStrategy,
+            scheduler.generation_strategy,
+            exception=ValueError(
+                "This functionality is only supported with a local GenerationStrategy"
+            ),
+        )
+        ._steps[0]
+        .num_trials
+    )
     score_trace = compute_score_trace(
         optimization_trace=optimization_trace,
         num_baseline_trials=num_baseline_trials,
diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py
index 8bda0cfd095..3caee6a4625 100644
--- a/ax/service/ax_client.py
+++ b/ax/service/ax_client.py
@@ -32,6 +32,7 @@
 from ax.core.arm import Arm
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.experiment import DataType, Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
 from ax.core.generator_run import GeneratorRun
 from ax.core.map_data import MapData
 from ax.core.map_metric import MapMetric
@@ -1730,7 +1731,7 @@ def _set_generation_strategy(
 
     def _save_generation_strategy_to_db_if_possible(
         self,
-        generation_strategy: Optional[GenerationStrategy] = None,
+        generation_strategy: Optional[GenerationStrategyInterface] = None,
         suppress_all_errors: bool = False,
     ) -> bool:
         return super()._save_generation_strategy_to_db_if_possible(
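For reference, `checked_cast` (imported above from `ax.utils.common.typeutils`) narrows a value to the requested type or raises. The snippet below is a minimal sketch of the behavior this patch relies on, assuming the helper raises the caller-supplied exception when the isinstance check fails; it is illustrative only, not the library implementation, and `checked_cast_sketch` is a hypothetical name.

from typing import Optional, Type, TypeVar

T = TypeVar("T")


def checked_cast_sketch(typ: Type[T], val: object, exception: Optional[Exception] = None) -> T:
    # Narrow ``val`` to ``typ``; raise the supplied exception (or a ValueError) otherwise.
    if not isinstance(val, typ):
        if exception is not None:
            raise exception
        raise ValueError(f"Expected an instance of {typ.__name__}, got {type(val).__name__}.")
    return val  # At this point ``val`` is known to be an instance of ``typ``.

This is the same guard applied above to `scheduler.generation_strategy` before touching `GenerationStrategy`-only attributes such as `_steps`.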
diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 6df0aa5dd25..40f8b939d7c 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -29,6 +29,7 @@
 import ax.service.utils.early_stopping as early_stopping_utils
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.experiment import Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
 from ax.core.generator_run import GeneratorRun
 from ax.core.map_data import MapData
 from ax.core.map_metric import MapMetric
@@ -68,7 +69,7 @@
     set_stderr_log_level,
 )
 from ax.utils.common.timeutils import current_timestamp_in_millis
-from ax.utils.common.typeutils import not_none
+from ax.utils.common.typeutils import checked_cast, not_none
 from pyre_extensions import assert_is_instance
 
 
@@ -161,7 +162,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
     """
 
     experiment: Experiment
-    generation_strategy: GenerationStrategy
+    generation_strategy: GenerationStrategyInterface
     options: SchedulerOptions
     logger: LoggerAdapter
     # Mapping of form {short string identifier -> message to show in reported
@@ -205,7 +206,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
     def __init__(
         self,
         experiment: Experiment,
-        generation_strategy: GenerationStrategy,
+        generation_strategy: GenerationStrategyInterface,
         options: SchedulerOptions,
         db_settings: Optional[DBSettings] = None,
         _skip_experiment_save: bool = False,
@@ -220,7 +221,7 @@ def __init__(
         if not isinstance(experiment, Experiment):
             raise TypeError("{experiment} is not an Ax experiment.")
 
-        if not isinstance(generation_strategy, GenerationStrategy):
+        if not isinstance(generation_strategy, GenerationStrategyInterface):
             raise TypeError("{generation_strategy} is not a generation strategy.")
 
         self._validate_options(options=options)
@@ -381,6 +382,19 @@ def runner(self) -> Runner:
         """
         return not_none(self.experiment.runner)
 
+    @property
+    def concrete_generation_strategy(self) -> GenerationStrategy:
+        """Concrete ``GenerationStrategy`` associated with this ``Scheduler``
+        instance.
+        """
+        return checked_cast(
+            GenerationStrategy,
+            self.generation_strategy,
+            UnsupportedError(
+                "This functionality is only supported with a local GenerationStrategy"
+            ),
+        )
+
     def __repr__(self) -> str:
         """Short user-friendly string representation."""
         if not hasattr(self, "experiment"):
@@ -446,7 +460,7 @@ def get_best_trial(
     ) -> Optional[Tuple[int, TParameterization, Optional[TModelPredictArm]]]:
         return self._get_best_trial(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.concrete_generation_strategy,
             optimization_config=optimization_config,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
@@ -461,7 +475,7 @@ def get_pareto_optimal_parameters(
     ) -> Optional[Dict[int, Tuple[TParameterization, TModelPredictArm]]]:
         return self._get_pareto_optimal_parameters(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.concrete_generation_strategy,
             optimization_config=optimization_config,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
@@ -476,7 +490,7 @@ def get_hypervolume(
     ) -> float:
         return BestPointMixin._get_hypervolume(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.concrete_generation_strategy,
             optimization_config=optimization_config,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
@@ -1449,12 +1463,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
         Returns:
             List of trials, empty if generation is not possible.
         """
-        pending = get_pending_observation_features_based_on_trial_status(
-            experiment=self.experiment
-        )
         try:
             generator_runs = self._gen_new_trials_from_generation_strategy(
-                num_trials=num_trials, n=n, pending=pending
+                num_trials=num_trials, n=n
             )
         except OptimizationComplete as err:
             completion_str = f"Optimization complete: {err}"
@@ -1483,9 +1494,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
             self.logger.debug(f"Message from generation strategy: {err}")
             return []
 
-        if (
-            self.options.trial_type == TrialType.TRIAL
-            and len(generator_runs[0].arms) > 1
+        if self.options.trial_type == TrialType.TRIAL and any(
+            len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1
+            for generator_run_list in generator_runs
        ):
             raise SchedulerInternalError(
                 "Generation strategy produced multiple arms when only one was expected."
@@ -1493,32 +1504,30 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
 
         return [
             self.experiment.new_batch_trial(
-                generator_run=generator_run,
+                generator_runs=generator_run_list,
                 ttl_seconds=self.options.ttl_seconds_for_trials,
             )
             if self.options.trial_type == TrialType.BATCH_TRIAL
             else self.experiment.new_trial(
-                generator_run=generator_run,
+                generator_run=generator_run_list[0],
                 ttl_seconds=self.options.ttl_seconds_for_trials,
             )
-            for generator_run in generator_runs
+            for generator_run_list in generator_runs
         ]
 
     def _gen_new_trials_from_generation_strategy(
         self,
         num_trials: int,
         n: int,
-        pending: Optional[Dict[str, List[ObservationFeatures]]],
-    ) -> List[GeneratorRun]:
+    ) -> List[List[GeneratorRun]]:
         """Generates a list ``GeneratorRun``s of length of ``num_trials``
         using the ``_gen_multiple`` method of the scheduler's
         ``generation_strategy``, taking into account any ``pending`` observations.
         """
-        return self.generation_strategy._gen_multiple(
+        return self.generation_strategy.gen_for_multiple_trials_with_multiple_models(
             experiment=self.experiment,
             num_generator_runs=num_trials,
             n=n,
-            pending_observations=pending,
         )
 
     def _update_and_save_trials(
@@ -1853,7 +1862,7 @@ def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
     Returns:
         A ModelBridge object fitted to the observations of the scheduler's experiment.
     """
-    gs = scheduler.generation_strategy  # GenerationStrategy
+    gs = scheduler.concrete_generation_strategy
     model_bridge = gs.model  # Optional[ModelBridge]
     if model_bridge is None:  # Need to re-fit the model.
         data = scheduler.experiment.fetch_data()
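For orientation, the sketch below illustrates the new return shape: `gen_for_multiple_trials_with_multiple_models` is expected to yield one inner list of `GeneratorRun`s per requested trial, and `_get_next_trials` turns each inner list into a single trial. The names `build_trials_sketch`, `make_batch_trial`, and `make_single_trial` are hypothetical stand-ins for `Experiment.new_batch_trial` / `Experiment.new_trial`; this illustrates the mapping only and is not the Scheduler code itself.

from typing import Callable, List, Sequence


def build_trials_sketch(
    generator_run_lists: Sequence[Sequence["GeneratorRun"]],
    use_batch_trials: bool,
    make_batch_trial: Callable[..., "BaseTrial"],
    make_single_trial: Callable[..., "BaseTrial"],
) -> List["BaseTrial"]:
    trials: List["BaseTrial"] = []
    for generator_run_list in generator_run_lists:
        if use_batch_trials:
            # A batch trial can hold every generator run produced for it.
            trials.append(make_batch_trial(generator_runs=list(generator_run_list)))
        else:
            # A one-arm trial consumes only the first (and only) generator run.
            trials.append(make_single_trial(generator_run=generator_run_list[0]))
    return trials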
"get_pending_observation_features_based_on_trial_status" ), side_effect=get_pending_observation_features_based_on_trial_status, @@ -987,8 +988,9 @@ def test_base_report_results(self) -> None: ) self.assertEqual(scheduler.run_n_trials(max_trials=3), OptimizationResult()) - @patch( - f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple", + @patch.object( + GenerationStrategy, + "gen_for_multiple_trials_with_multiple_models", side_effect=OptimizationComplete("test error"), ) def test_optimization_complete(self, _) -> None: @@ -1327,7 +1329,7 @@ def test_get_fitted_model_bridge(self) -> None: with patch.object( GenerationStrategy, "_fit_current_model", - wraps=scheduler.generation_strategy._fit_current_model, + wraps=generation_strategy._fit_current_model, ) as fit_model: get_fitted_model_bridge(scheduler) fit_model.assert_called_once() diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index c4c55359b47..874a98cb8f8 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -12,6 +12,7 @@ from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.exceptions.core import ( IncompatibleDependencyVersion, @@ -155,7 +156,7 @@ def _get_experiment_and_generation_strategy_db_id( return exp_id, gs_id def _maybe_save_experiment_and_generation_strategy( - self, experiment: Experiment, generation_strategy: GenerationStrategy + self, experiment: Experiment, generation_strategy: GenerationStrategyInterface ) -> Tuple[bool, bool]: """If DB settings are set on this `WithDBSettingsBase` instance, checks whether given experiment and generation strategy are already saved and @@ -304,7 +305,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible( self, experiment: Experiment, trials: List[BaseTrial], - generation_strategy: GenerationStrategy, + generation_strategy: GenerationStrategyInterface, new_generator_runs: List[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> None: @@ -386,7 +387,7 @@ def _save_or_update_trials_in_db_if_possible( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: Optional[GenerationStrategy] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, suppress_all_errors: bool = False, ) -> bool: """Saves given generation strategy if DB settings are set on this @@ -399,18 +400,21 @@ def _save_generation_strategy_to_db_if_possible( bool: Whether the generation strategy was saved. 
""" if self.db_settings_set and generation_strategy is not None: - _save_generation_strategy_to_db_if_possible( - generation_strategy=generation_strategy, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - ) + # only local GenerationStrategies should need to be saved to + # the database because only they make changes locally + if isinstance(generation_strategy, GenerationStrategy): + _save_generation_strategy_to_db_if_possible( + generation_strategy=generation_strategy, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + ) return True return False def _update_generation_strategy_in_db_if_possible( self, - generation_strategy: GenerationStrategy, + generation_strategy: GenerationStrategyInterface, new_generator_runs: List[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> bool: @@ -427,15 +431,18 @@ def _update_generation_strategy_in_db_if_possible( bool: Whether the experiment was saved. """ if self.db_settings_set: - _update_generation_strategy_in_db_if_possible( - generation_strategy=generation_strategy, - new_generator_runs=new_generator_runs, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - reduce_state_generator_runs=reduce_state_generator_runs, - ) - return True + # only local GenerationStrategies should need to be saved to + # the database because only they make changes locally + if isinstance(generation_strategy, GenerationStrategy): + _update_generation_strategy_in_db_if_possible( + generation_strategy=generation_strategy, + new_generator_runs=new_generator_runs, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + reduce_state_generator_runs=reduce_state_generator_runs, + ) + return True return False def _update_experiment_properties_in_db( diff --git a/ax/telemetry/generation_strategy.py b/ax/telemetry/generation_strategy.py index b510e63d5f9..4b9c95f9cee 100644 --- a/ax/telemetry/generation_strategy.py +++ b/ax/telemetry/generation_strategy.py @@ -9,7 +9,6 @@ from math import inf from ax.modelbridge.generation_strategy import GenerationStrategy - from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py index a51491f456d..57cb9a6e2d3 100644 --- a/ax/telemetry/scheduler.py +++ b/ax/telemetry/scheduler.py @@ -49,7 +49,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord: ), generation_strategy_created_record=( GenerationStrategyCreatedRecord.from_generation_strategy( - generation_strategy=scheduler.generation_strategy + generation_strategy=scheduler.concrete_generation_strategy ) ), scheduler_total_trials=scheduler.options.total_trials, @@ -68,7 +68,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord: ), transformed_dimensionality=_get_max_transformed_dimensionality( search_space=scheduler.experiment.search_space, - generation_strategy=scheduler.generation_strategy, + generation_strategy=scheduler.concrete_generation_strategy, ), ) diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py index f46fedbaac1..eec5b85c1aa 100644 --- a/ax/telemetry/tests/test_scheduler.py +++ b/ax/telemetry/tests/test_scheduler.py @@ -45,7 +45,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None: ), generation_strategy_created_record=( 
diff --git a/ax/telemetry/generation_strategy.py b/ax/telemetry/generation_strategy.py
index b510e63d5f9..4b9c95f9cee 100644
--- a/ax/telemetry/generation_strategy.py
+++ b/ax/telemetry/generation_strategy.py
@@ -9,7 +9,6 @@
 
 from math import inf
 
 from ax.modelbridge.generation_strategy import GenerationStrategy
-
 from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS
 
diff --git a/ax/telemetry/scheduler.py b/ax/telemetry/scheduler.py
index a51491f456d..57cb9a6e2d3 100644
--- a/ax/telemetry/scheduler.py
+++ b/ax/telemetry/scheduler.py
@@ -49,7 +49,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
             ),
             generation_strategy_created_record=(
                 GenerationStrategyCreatedRecord.from_generation_strategy(
-                    generation_strategy=scheduler.generation_strategy
+                    generation_strategy=scheduler.concrete_generation_strategy
                 )
             ),
             scheduler_total_trials=scheduler.options.total_trials,
@@ -68,7 +68,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
             ),
             transformed_dimensionality=_get_max_transformed_dimensionality(
                 search_space=scheduler.experiment.search_space,
-                generation_strategy=scheduler.generation_strategy,
+                generation_strategy=scheduler.concrete_generation_strategy,
             ),
         )
 
diff --git a/ax/telemetry/tests/test_scheduler.py b/ax/telemetry/tests/test_scheduler.py
index f46fedbaac1..eec5b85c1aa 100644
--- a/ax/telemetry/tests/test_scheduler.py
+++ b/ax/telemetry/tests/test_scheduler.py
@@ -45,7 +45,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
             ),
             generation_strategy_created_record=(
                 GenerationStrategyCreatedRecord.from_generation_strategy(
-                    generation_strategy=scheduler.generation_strategy
+                    generation_strategy=scheduler.concrete_generation_strategy
                 )
             ),
             scheduler_total_trials=0,
@@ -63,7 +63,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
                 experiment=scheduler.experiment
             ).__dict__,
             **GenerationStrategyCreatedRecord.from_generation_strategy(
-                generation_strategy=scheduler.generation_strategy
+                generation_strategy=scheduler.concrete_generation_strategy
             ).__dict__,
             "scheduler_total_trials": 0,
             "scheduler_max_pending_trials": 10,