diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py
index f3b71dd2f16..e3e43d74b1e 100644
--- a/ax/preview/api/client.py
+++ b/ax/preview/api/client.py
@@ -26,7 +26,10 @@ from ax.core.runner import Runner
 from ax.core.trial import Trial
 from ax.core.utils import get_pending_observation_features_based_on_trial_status
-from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
+from ax.early_stopping.strategies import (
+    BaseEarlyStoppingStrategy,
+    PercentileEarlyStoppingStrategy,
+)
 from ax.exceptions.core import UnsupportedError, UserInputError
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
 from ax.modelbridge.generation_strategy import GenerationStrategy
@@ -55,9 +58,9 @@ class Client:
-    _experiment: Experiment | None = None
-    _generation_strategy: GenerationStrategy | None = None
-    _early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
+    _maybe_experiment: Experiment | None = None
+    _maybe_generation_strategy: GenerationStrategy | None = None
+    _maybe_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
 
     def __init__(
         self,
@@ -89,13 +92,13 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None:
         Saves to database on completion if db_config is present.
         """
-        if self._experiment is not None:
+        if self._maybe_experiment is not None:
             raise UnsupportedError(
                 "Experiment already configured. Please create a new Client if you "
                 "would like a new experiment."
             )
 
-        self._experiment = experiment_from_config(config=experiment_config)
+        self._maybe_experiment = experiment_from_config(config=experiment_config)
 
         if self._db_config is not None:
             # TODO[mpolson64] Save to database
@@ -137,11 +140,9 @@ def configure_optimization(
         Saves to database on completion if db_config is present.
         """
-        self._none_throws_experiment().optimization_config = (
-            optimization_config_from_string(
-                objective_str=objective,
-                outcome_constraint_strs=outcome_constraints,
-            )
+        self._experiment.optimization_config = optimization_config_from_string(
+            objective_str=objective,
+            outcome_constraint_strs=outcome_constraints,
         )
 
         if self._db_config is not None:
@@ -159,8 +160,8 @@ def configure_generation_strategy(
         """
         generation_strategy = choose_generation_strategy(
-            search_space=self._none_throws_experiment().search_space,
-            optimization_config=self._none_throws_experiment().optimization_config,
+            search_space=self._experiment.search_space,
+            optimization_config=self._experiment.optimization_config,
             num_trials=generation_strategy_config.num_trials,
             num_initialization_trials=(
                 generation_strategy_config.num_initialization_trials
@@ -170,9 +171,9 @@ def configure_generation_strategy(
         )
 
         # Necessary for storage implications, may be removed in the future
-        generation_strategy._experiment = self._none_throws_experiment()
+        generation_strategy._experiment = self._experiment
 
-        self._generation_strategy = generation_strategy
+        self._maybe_generation_strategy = generation_strategy
 
         if self._db_config is not None:
             # TODO[mpolson64] Save to database
@@ -207,7 +208,7 @@ def set_experiment(self, experiment: Experiment) -> None:
         Saves to database on completion if db_config is present.
         """
-        self._experiment = experiment
+        self._maybe_experiment = experiment
 
         if self._db_config is not None:
             # TODO[mpolson64] Save to database
@@ -223,7 +224,7 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> No
         Saves to database on completion if db_config is present.
""" - self._none_throws_experiment().optimization_config = optimization_config + self._experiment.optimization_config = optimization_config if self._db_config is not None: # TODO[mpolson64] Save to database @@ -239,14 +240,9 @@ def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> No Saves to database on completion if db_config is present. """ - self._generation_strategy = generation_strategy - none_throws( - self._generation_strategy - )._experiment = self._none_throws_experiment() + self._maybe_generation_strategy = generation_strategy - none_throws( - self._generation_strategy - )._experiment = self._none_throws_experiment() + none_throws(self._maybe_generation_strategy)._experiment = self._experiment if self._db_config is not None: # TODO[mpolson64] Save to database @@ -265,7 +261,7 @@ def set_early_stopping_strategy( Saves to database on completion if db_config is present. """ - self._early_stopping_strategy = early_stopping_strategy + self._maybe_early_stopping_strategy = early_stopping_strategy if self._db_config is not None: # TODO[mpolson64] Save to database @@ -281,7 +277,7 @@ def _set_runner(self, runner: Runner) -> None: Saves to database on completion if db_config is present. """ - self._none_throws_experiment().runner = runner + self._experiment.runner = runner if self._db_config is not None: # TODO[mpolson64] Save to database @@ -328,7 +324,7 @@ def get_next_trials( A mapping of trial index to parameterization. """ - if self._none_throws_experiment().optimization_config is None: + if self._experiment.optimization_config is None: raise UnsupportedError( "OptimizationConfig not set. Please call configure_optimization before " "generating trials." @@ -340,10 +336,10 @@ def get_next_trials( # This will be changed to use gen directly post gen-unfication cc @mgarrard generator_runs = gs.gen_for_multiple_trials_with_multiple_models( - experiment=self._none_throws_experiment(), + experiment=self._experiment, pending_observations=( get_pending_observation_features_based_on_trial_status( - experiment=self._none_throws_experiment() + experiment=self._experiment ) ), n=1, @@ -359,7 +355,7 @@ def get_next_trials( for generator_run in generator_runs: trial = assert_is_instance( - self._none_throws_experiment().new_trial( + self._experiment.new_trial( generator_run=generator_run[0], ), Trial, @@ -398,20 +394,18 @@ def complete_trial( trial_index=trial_index, raw_data=raw_data, progression=progression ) - experiment = self._none_throws_experiment() - # If no OptimizationConfig is set, mark the trial as COMPLETED - if (optimization_config := experiment.optimization_config) is None: - experiment.trials[trial_index].mark_completed() + if (optimization_config := self._experiment.optimization_config) is None: + self._experiment.trials[trial_index].mark_completed() else: - trial_data = experiment.lookup_data(trial_indices=[trial_index]) + trial_data = self._experiment.lookup_data(trial_indices=[trial_index]) missing_metrics = {*optimization_config.metrics.keys()} - { *trial_data.metric_names } # If all necessary metrics are present mark the trial as COMPLETED if len(missing_metrics) == 0: - experiment.trials[trial_index].mark_completed() + self._experiment.trials[trial_index].mark_completed() # If any metrics are missing mark the trial as FAILED else: @@ -425,7 +419,7 @@ def complete_trial( # TODO[mpolson64] Save trial ... 
-        return experiment.trials[trial_index].status
+        return self._experiment.trials[trial_index].status
 
     def attach_data(
         self,
@@ -448,9 +442,7 @@ def attach_data(
             ({"step": progression if progression is not None else np.nan}, raw_data)
         ]
 
-        trial = assert_is_instance(
-            self._none_throws_experiment().trials[trial_index], Trial
-        )
+        trial = assert_is_instance(self._experiment.trials[trial_index], Trial)
         trial.update_trial_data(
             # pyre-fixme[6]: Type narrowing broken because core Ax TParameterization
             # is dict not Mapping
@@ -476,7 +468,7 @@ def attach_trial(
         Returns:
             The index of the attached trial.
         """
-        _, trial_index = self._none_throws_experiment().attach_trial(
+        _, trial_index = self._experiment.attach_trial(
             # pyre-fixme[6]: Type narrowing broken because core Ax TParameterization
             # is dict not Mapping
             parameterizations=[parameters],
@@ -508,8 +500,8 @@ def attach_baseline(
             arm_name=arm_name or "baseline",
         )
 
-        self._none_throws_experiment().status_quo = assert_is_instance(
-            self._none_throws_experiment().trials[trial_index], Trial
+        self._experiment.status_quo = assert_is_instance(
+            self._experiment.trials[trial_index], Trial
         ).arm
 
         if self._db_config is not None:
@@ -528,18 +520,11 @@ def should_stop_trial_early(self, trial_index: int) -> bool:
         Returns:
             Whether the trial should be stopped early.
         """
-        if self._early_stopping_strategy is None:
-            # In the future we may want to support inferring a default early stopping
-            # strategy
-            raise UnsupportedError(
-                "Early stopping strategy not set. Please set an early stopping "
-                "strategy before calling should_stop_trial_early."
-            )
-
         es_response = none_throws(
-            self._early_stopping_strategy
+            self._early_stopping_strategy_or_choose()
         ).should_stop_trials_early(
-            trial_indices={trial_index}, experiment=self._none_throws_experiment()
+            trial_indices={trial_index}, experiment=self._experiment
         )
 
         # TODO[mpolson64]: log the returned reason for stopping the trial
@@ -553,7 +538,7 @@ def mark_trial_failed(self, trial_index: int) -> None:
         Saves to database on completion if db_config is present.
         """
-        self._none_throws_experiment().trials[trial_index].mark_failed()
+        self._experiment.trials[trial_index].mark_failed()
 
         if self._db_config is not None:
             # TODO[mpolson64] Save to database
@@ -567,7 +552,7 @@ def mark_trial_abandoned(self, trial_index: int) -> None:
         Saves to database on completion if db_config is present.
""" - self._none_throws_experiment().trials[trial_index].mark_abandoned() + self._experiment.trials[trial_index].mark_abandoned() if self._db_config is not None: # TODO[mpolson64] Save to database @@ -590,7 +575,7 @@ def mark_trial_early_stopped( trial_index=trial_index, raw_data=raw_data, progression=progression ) - self._none_throws_experiment().trials[trial_index].mark_early_stopped() + self._experiment.trials[trial_index].mark_early_stopped() if self._db_config is not None: # TODO[mpolson64] Save to database @@ -606,8 +591,8 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: """ scheduler = Scheduler( - experiment=self._none_throws_experiment(), - generation_strategy=(self._generation_strategy_or_choose()), + experiment=self._experiment, + generation_strategy=self._generation_strategy_or_choose(), options=SchedulerOptions( max_pending_trials=options.parallelism, tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, @@ -643,15 +628,15 @@ def compute_analyses( analyses = ( analyses if analyses is not None - else choose_analyses(experiment=self._none_throws_experiment()) + else choose_analyses(experiment=self._experiment) ) # Compute Analyses one by one and accumulate Results holding either the # AnalysisCard or an Exception and some metadata results = [ analysis.compute_result( - experiment=self._none_throws_experiment(), - generation_strategy=self._generation_strategy_or_choose(), + experiment=self._experiment, + generation_strategy=self._generation_strategy, ) for analysis in analyses ] @@ -689,7 +674,7 @@ def get_best_arm( use_model_predictions=True, otherwise returns observed data. """ - if len(self._none_throws_experiment().trials) < 1: + if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_best_arm." @@ -699,20 +684,14 @@ def get_best_arm( # unwanted public methods trial_index, parameters, _ = none_throws( BestPointMixin._get_best_trial( - experiment=self._none_throws_experiment(), - # Requiring true GenerationStrategy here, ideally we will loosen this - # in the future - generation_strategy=assert_is_instance( - self._generation_strategy_or_choose(), GenerationStrategy - ), + experiment=self._experiment, + generation_strategy=self._generation_strategy, use_model_predictions=use_model_predictions, ) ) arm = none_throws( - assert_is_instance( - self._none_throws_experiment().trials[trial_index], Trial - ).arm + assert_is_instance(self._experiment.trials[trial_index], Trial).arm ) if use_model_predictions: @@ -727,11 +706,9 @@ def get_best_arm( prediction = {} else: - data_dict = ( - self._none_throws_experiment() - .lookup_data(trial_indices=[trial_index]) - .df.to_dict() - ) + data_dict = self._experiment.lookup_data( + trial_indices=[trial_index] + ).df.to_dict() prediction = { data_dict["metric_name"][i]: (data_dict["mean"][i], data_dict["sem"][i]) @@ -753,19 +730,17 @@ def get_pareto_frontier( outcome. """ - if len(self._none_throws_experiment().trials) < 1: + if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_pareto_frontier." 
             )
 
         frontier = BestPointMixin._get_pareto_optimal_parameters(
-            experiment=self._none_throws_experiment(),
+            experiment=self._experiment,
             # Requiring true GenerationStrategy here, ideally we will loosen this
             # in the future
-            generation_strategy=assert_is_instance(
-                self._generation_strategy_or_choose(), GenerationStrategy
-            ),
+            generation_strategy=self._generation_strategy,
             use_model_predictions=use_model_predictions,
         )
@@ -773,9 +748,7 @@ def get_pareto_frontier(
         arm_names = [
             none_throws(
-                assert_is_instance(
-                    self._none_throws_experiment().trials[trial_index], Trial
-                ).arm
+                assert_is_instance(self._experiment.trials[trial_index], Trial).arm
             ).name
             for trial_index, _ in frontier_list
         ]
@@ -796,11 +769,9 @@ def get_pareto_frontier(
         else:
             predictions = []
             for trial_index in frontier.keys():
-                data_dict = (
-                    self._none_throws_experiment()
-                    .lookup_data(trial_indices=[trial_index])
-                    .df.to_dict()
-                )
+                data_dict = self._experiment.lookup_data(
+                    trial_indices=[trial_index]
+                ).df.to_dict()
 
                 predictions.append(
                     {
@@ -842,9 +813,8 @@ def predict(
         Returns:
             A list of mappings from metric name to predicted mean and SEM
         """
-        search_space = self._none_throws_experiment().search_space
         for parameters in points:
-            search_space.check_membership(
+            self._experiment.search_space.check_membership(
                 # pyre-fixme[6]: Core Ax allows users to specify TParameterization
                 # values as None but we do not allow this in the API.
                 parameterization=parameters,
@@ -852,10 +822,8 @@ def predict(
                 check_all_parameters_present=True,
             )
 
-        generation_strategy = self._generation_strategy_or_choose()
-
         try:
-            mean, covariance = none_throws(generation_strategy.model).predict(
+            mean, covariance = none_throws(self._generation_strategy.model).predict(
                 observation_features=[
                     # pyre-fixme[6]: Core Ax allows users to specify TParameterization
                     # values as None but we do not allow this in the API.
@@ -916,15 +884,39 @@ def load_from_database(
         """
         ...
 
-    def _none_throws_experiment(self) -> Experiment:
+    # -------------------- Section 5: Private Methods -------------------------------
+
+    # -------------------- Section 5.1: Getters and defaults ------------------------
+
+    @property
+    def _experiment(self) -> Experiment:
         return none_throws(
-            self._experiment,
+            self._maybe_experiment,
             (
                 "Experiment not set. Please call configure_experiment or load an "
                 "experiment before utilizing any other methods on the Client."
             ),
         )
 
+    @property
+    def _generation_strategy(self) -> GenerationStrategy:
+        return none_throws(
+            self._maybe_generation_strategy,
+            (
+                "GenerationStrategy not set. Please call "
+                "configure_generation_strategy, load a GenerationStrategy, or call "
+                "get_next_trials or run_trials to automatically choose a "
+                "GenerationStrategy before utilizing any other methods on the Client "
+                "which require one."
+            ),
+        )
+
+    @property
+    def _early_stopping_strategy(self) -> BaseEarlyStoppingStrategy:
+        return none_throws(
+            self._maybe_early_stopping_strategy,
+            "Early stopping strategy not set. Please set an early stopping strategy "
+            "before calling should_stop_trial_early.",
+        )
+
     def _generation_strategy_or_choose(
         self,
     ) -> GenerationStrategy:
@@ -933,12 +925,33 @@ def _generation_strategy_or_choose(
         return it.
""" - if self._generation_strategy is None: + try: + return self._generation_strategy + except AssertionError: self.configure_generation_strategy( generation_strategy_config=GenerationStrategyConfig() ) - return none_throws(self._generation_strategy) + return self._generation_strategy + + def _early_stopping_strategy_or_choose( + self, + ) -> BaseEarlyStoppingStrategy: + """ + If an EarlyStoppingStrategy is not set choose a default one and return it. + """ + + try: + return self._early_stopping_strategy + except AssertionError: + # PercetinleEarlyStoppingStrategy may or may not have sensible defaults at + # current moment -- we will need to be critical of these settings during + # benchmarking + self.set_early_stopping_strategy( + early_stopping_strategy=PercentileEarlyStoppingStrategy() + ) + + return self._early_stopping_strategy def _overwrite_metric(self, metric: Metric) -> None: """ @@ -951,9 +964,7 @@ def _overwrite_metric(self, metric: Metric) -> None: """ # Check the OptimizationConfig first - if ( - optimization_config := self._none_throws_experiment().optimization_config - ) is not None: + if (optimization_config := self._experiment.optimization_config) is not None: # Check the objective if isinstance( multi_objective := optimization_config.objective, MultiObjective @@ -989,14 +1000,13 @@ def _overwrite_metric(self, metric: Metric) -> None: return # Check the tracking metrics - tracking_metric_names = self._none_throws_experiment()._tracking_metrics.keys() - if metric.name in tracking_metric_names: - self._none_throws_experiment()._tracking_metrics[metric.name] = metric + if metric.name in self._experiment._tracking_metrics.keys(): + self._experiment._tracking_metrics[metric.name] = metric return # If an equivalently named Metric does not exist, add it as a tracking # metric. - self._none_throws_experiment().add_tracking_metric(metric=metric) + self._experiment.add_tracking_metric(metric=metric) logger.warning( f"Metric {metric} not found in optimization config, added as tracking " "metric." 
diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py
index 6df01eb8446..804bde92baa 100644
--- a/ax/preview/api/tests/test_client.py
+++ b/ax/preview/api/tests/test_client.py
@@ -158,7 +158,7 @@ def test_configure_optimization(self) -> None:
         )
 
         self.assertEqual(
-            none_throws(client._experiment).optimization_config,
+            client._experiment.optimization_config,
             OptimizationConfig(
                 objective=Objective(metric=Metric(name="ne"), minimize=True),
                 outcome_constraints=[
@@ -190,7 +190,7 @@ def test_configure_runner(self) -> None:
         client.set_experiment(experiment=get_branin_experiment())
         client.configure_runner(runner=runner)
 
-        self.assertEqual(none_throws(client._experiment).runner, runner)
+        self.assertEqual(client._experiment.runner, runner)
 
     def test_configure_metric(self) -> None:
         client = Client()
@@ -216,9 +216,7 @@ def test_configure_metric(self) -> None:
 
         self.assertEqual(
             custom_metric,
-            none_throws(
-                none_throws(client._experiment).optimization_config
-            ).objective.metric,
+            none_throws(client._experiment.optimization_config).objective.metric,
         )
 
         # Test replacing a multi-objective
@@ -228,9 +226,7 @@ def test_configure_metric(self) -> None:
         self.assertIn(
             custom_metric,
             assert_is_instance(
-                none_throws(
-                    none_throws(client._experiment).optimization_config
-                ).objective,
+                none_throws(client._experiment.optimization_config).objective,
                 MultiObjective,
             ).metrics,
         )
@@ -241,9 +237,7 @@ def test_configure_metric(self) -> None:
         self.assertIn(
             custom_metric,
             assert_is_instance(
-                none_throws(
-                    none_throws(client._experiment).optimization_config
-                ).objective,
+                none_throws(client._experiment.optimization_config).objective,
                 ScalarizedObjective,
             ).metrics,
         )
@@ -256,7 +250,7 @@ def test_configure_metric(self) -> None:
 
         self.assertEqual(
             custom_metric,
-            none_throws(none_throws(client._experiment).optimization_config)
+            none_throws(client._experiment.optimization_config)
             .outcome_constraints[0]
             .metric,
         )
@@ -265,12 +259,12 @@ def test_configure_metric(self) -> None:
         client.configure_optimization(
             objective="foo",
         )
-        none_throws(client._experiment).add_tracking_metric(metric=Metric("custom"))
+        client._experiment.add_tracking_metric(metric=Metric("custom"))
         client.configure_metrics(metrics=[custom_metric])
 
         self.assertEqual(
             custom_metric,
-            none_throws(client._experiment).tracking_metrics[0],
+            client._experiment.tracking_metrics[0],
         )
 
         # Test adding a tracking metric
@@ -289,7 +283,7 @@ def test_configure_metric(self) -> None:
 
         self.assertEqual(
             custom_metric,
-            none_throws(client._experiment).tracking_metrics[0],
+            client._experiment.tracking_metrics[0],
         )
 
     def test_set_experiment(self) -> None:
@@ -312,9 +306,7 @@ def test_set_optimization_config(self) -> None:
             optimization_config=optimization_config,
         )
 
-        self.assertEqual(
-            none_throws(client._experiment).optimization_config, optimization_config
-        )
+        self.assertEqual(client._experiment.optimization_config, optimization_config)
 
     def test_set_generation_strategy(self) -> None:
         client = Client()
@@ -413,14 +405,12 @@ def test_attach_data(self) -> None:
         client.attach_data(trial_index=trial_index, raw_data={"foo": 1.0})
 
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.RUNNING,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -440,14 +430,12 @@ def test_attach_data(self) -> None:
         client.attach_data(trial_index=0, raw_data={"foo": 2.0}, progression=10)
 
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.RUNNING,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -469,14 +457,12 @@ def test_attach_data(self) -> None:
             raw_data={"foo": 1.0, "bar": 2.0},
         )
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.RUNNING,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -517,14 +503,12 @@ def test_complete_trial(self) -> None:
         )
 
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.COMPLETED,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -547,15 +531,13 @@ def test_complete_trial(self) -> None:
         )
 
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.COMPLETED,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -576,14 +558,12 @@ def test_complete_trial(self) -> None:
         client.complete_trial(trial_index=trial_index, raw_data={"foo": 1.0})
 
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.FAILED,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -618,9 +598,7 @@ def test_attach_trial(self) -> None:
         client.configure_optimization(objective="foo")
 
         trial_index = client.attach_trial(parameters={"x1": 0.5}, arm_name="bar")
-        trial = assert_is_instance(
-            none_throws(client._experiment).trials[trial_index], Trial
-        )
+        trial = assert_is_instance(client._experiment.trials[trial_index], Trial)
         self.assertEqual(none_throws(trial.arm).parameters, {"x1": 0.5})
         self.assertEqual(none_throws(trial.arm).name, "bar")
         self.assertEqual(trial.status, TrialStatus.RUNNING)
@@ -644,14 +622,12 @@ def test_attach_baseline(self) -> None:
         client.configure_optimization(objective="foo")
 
         trial_index = client.attach_baseline(parameters={"x1": 0.5})
-        trial = assert_is_instance(
-            none_throws(client._experiment).trials[trial_index], Trial
-        )
+        trial = assert_is_instance(client._experiment.trials[trial_index], Trial)
         self.assertEqual(none_throws(trial.arm).parameters, {"x1": 0.5})
         self.assertEqual(none_throws(trial.arm).name, "baseline")
         self.assertEqual(trial.status, TrialStatus.RUNNING)
 
-        self.assertEqual(client._none_throws_experiment().status_quo, trial.arm)
+        self.assertEqual(client._experiment.status_quo, trial.arm)
 
     def test_mark_trial_failed(self) -> None:
         client = Client()
@@ -671,7 +647,7 @@ def test_mark_trial_failed(self) -> None:
         trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
         client.mark_trial_failed(trial_index=trial_index)
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.FAILED,
         )
 
@@ -693,7 +669,7 @@ def test_mark_trial_abandoned(self) -> None:
         trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0]
         client.mark_trial_abandoned(trial_index=trial_index)
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.ABANDONED,
         )
 
@@ -717,14 +693,12 @@ def test_mark_trial_early_stopped(self) -> None:
             trial_index=trial_index, raw_data={"foo": 0.0}, progression=1
         )
         self.assertEqual(
-            none_throws(client._experiment).trials[trial_index].status,
+            client._experiment.trials[trial_index].status,
             TrialStatus.EARLY_STOPPED,
         )
         self.assertTrue(
             assert_is_instance(
-                none_throws(client._experiment).lookup_data(
-                    trial_indices=[trial_index]
-                ),
+                client._experiment.lookup_data(trial_indices=[trial_index]),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -755,11 +729,6 @@ def test_should_stop_trial_early(self) -> None:
         )
         client.configure_optimization(objective="foo")
 
-        with self.assertRaisesRegex(
-            UnsupportedError, "Early stopping strategy not set"
-        ):
-            client.should_stop_trial_early(trial_index=0)
-
         client.set_early_stopping_strategy(
             early_stopping_strategy=PercentileEarlyStoppingStrategy(
                 metric_names=["foo"]
@@ -788,20 +757,18 @@ def test_run_trials(self) -> None:
 
         client.run_trials(maximum_trials=4, options=OrchestrationConfig())
 
-        self.assertEqual(len(client._none_throws_experiment().trials), 4)
+        self.assertEqual(len(client._experiment.trials), 4)
         self.assertEqual(
             [
                 trial.index
-                for trial in client._none_throws_experiment().trials_by_status[
-                    TrialStatus.COMPLETED
-                ]
+                for trial in client._experiment.trials_by_status[TrialStatus.COMPLETED]
             ],
             [0, 1, 2, 3],
         )
         self.assertTrue(
             assert_is_instance(
-                client._none_throws_experiment().lookup_data(),
+                client._experiment.lookup_data(),
                 MapData,
             ).map_df.equals(
                 pd.DataFrame(
@@ -843,13 +810,11 @@ def test_get_next_trials_then_run_trials(self) -> None:
         _ = client.get_next_trials(maximum_trials=1)
 
         self.assertEqual(
-            len(
-                client._none_throws_experiment().trials_by_status[TrialStatus.COMPLETED]
-            ),
+            len(client._experiment.trials_by_status[TrialStatus.COMPLETED]),
             2,
         )
         self.assertEqual(
-            len(client._none_throws_experiment().trials_by_status[TrialStatus.RUNNING]),
+            len(client._experiment.trials_by_status[TrialStatus.RUNNING]),
             1,
         )
 
@@ -860,9 +825,7 @@ def test_get_next_trials_then_run_trials(self) -> None:
 
         # All trials should be COMPLETED
         self.assertEqual(
-            len(
-                client._none_throws_experiment().trials_by_status[TrialStatus.COMPLETED]
-            ),
+            len(client._experiment.trials_by_status[TrialStatus.COMPLETED]),
             5,
         )
 
@@ -880,6 +843,9 @@ def test_compute_analyses(self) -> None:
             )
         )
         client.configure_optimization(objective="foo")
+        client.configure_generation_strategy(
+            generation_strategy_config=GenerationStrategyConfig()
+        )
 
         with self.assertLogs(logger="ax.analysis", level="ERROR") as lg:
             cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()])
@@ -950,11 +916,11 @@ def test_get_best_arm(self) -> None:
             name,
             [
                 none_throws(assert_is_instance(trial, Trial).arm).name
-                for trial in client._none_throws_experiment().trials.values()
+                for trial in client._experiment.trials.values()
             ],
         )
         self.assertTrue(
-            client._none_throws_experiment().search_space.check_membership(
+            client._experiment.search_space.check_membership(
                 parameterization=parameters  # pyre-ignore[6]
             )
         )
@@ -973,11 +939,11 @@ def test_get_best_arm(self) -> None:
             name,
             [
                 none_throws(assert_is_instance(trial, Trial).arm).name
-                for trial in client._none_throws_experiment().trials.values()
+                for trial in client._experiment.trials.values()
             ],
         )
         self.assertTrue(
-            client._none_throws_experiment().search_space.check_membership(
+            client._experiment.search_space.check_membership(
                 parameterization=parameters  # pyre-ignore[6]
            )
         )
@@ -1026,11 +992,11 @@ def test_get_pareto_frontier(self) -> None:
             name,
             [
                 none_throws(assert_is_instance(trial, Trial).arm).name
-                for trial in client._none_throws_experiment().trials.values()
+                for trial in client._experiment.trials.values()
             ],
         )
         self.assertTrue(
-            client._none_throws_experiment().search_space.check_membership(
+            client._experiment.search_space.check_membership(
                 parameterization=parameters  # pyre-ignore[6]
            )
         )
@@ -1054,11 +1020,11 @@ def test_get_pareto_frontier(self) -> None:
             name,
             [
                 none_throws(assert_is_instance(trial, Trial).arm).name
-                for trial in client._none_throws_experiment().trials.values()
+                for trial in client._experiment.trials.values()
             ],
         )
         self.assertTrue(
-            client._none_throws_experiment().search_space.check_membership(
+            client._experiment.search_space.check_membership(
                 parameterization=parameters  # pyre-ignore[6]
             )
         )
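Taken together, the two files implement one pattern: optional state lives in `_maybe_*` fields, strict `@property` accessors replace the old `_none_throws_experiment()`-style helper calls, and `*_or_choose` helpers supply defaults on demand. That is why the test above no longer expects `should_stop_trial_early` to raise when no early stopping strategy is configured, and why `test_compute_analyses` now configures a GenerationStrategy explicitly. Below is a self-contained sketch of the pattern; the class and method names are hypothetical and not part of the Ax API.

from pyre_extensions import none_throws


class LazyClient:
    # Optional backing field, analogous to Client._maybe_generation_strategy.
    _maybe_strategy: str | None = None

    @property
    def _strategy(self) -> str:
        # Strict accessor: raises AssertionError with an actionable message if unset.
        return none_throws(
            self._maybe_strategy, "Strategy not set. Call set_strategy first."
        )

    def set_strategy(self, strategy: str) -> None:
        self._maybe_strategy = strategy

    def _strategy_or_choose(self) -> str:
        # Fallback accessor: choose a default when the strict accessor fails.
        try:
            return self._strategy
        except AssertionError:
            self.set_strategy("default-strategy")
            return self._strategy


client = LazyClient()
print(client._strategy_or_choose())  # -> "default-strategy", no exception raised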