From 407668ff839c5e12f0ccd012c8ed0ca11ac5001e Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Mon, 9 Dec 2024 08:00:03 -0800 Subject: [PATCH] Improve handling around optional attrs (#3151) Summary: After disccusion offline with lena-kashtelyan we decided that _none_throws_experiment() is a little verbose and it would be easier for developers to have a simple property with a short name to access attributes that may or may not be set on the Client. That being said we wanted to still keep them private as to not imply that it was supported behavior for users to call these functions and rely on their outputs' structure not changing. With that in mind we're making the following changes: 1. All optional attrs are prefixed with _maybe ex. _experiment becomes _maybe_experiment 2. All optional attrs get a @/property where the maybe is dropped and none_throws is called 3. _generation_strategy and _early_stopping_strategy get special methods that either return either return the existing attr or instantiate some default, call client.set, and return. This is important so a user that is fine with defaults will not have to call client.configure_generation_strategy(GenerationStrategyConfig()) manually Reviewed By: lena-kashtelyan Differential Revision: D66833556 --- ax/preview/api/client.py | 218 +++++++++++++++------------- ax/preview/api/tests/test_client.py | 126 ++++++---------- 2 files changed, 160 insertions(+), 184 deletions(-) diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index f3b71dd2f16..e3e43d74b1e 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -26,7 +26,10 @@ from ax.core.runner import Runner from ax.core.trial import Trial from ax.core.utils import get_pending_observation_features_based_on_trial_status -from ax.early_stopping.strategies import BaseEarlyStoppingStrategy +from ax.early_stopping.strategies import ( + BaseEarlyStoppingStrategy, + PercentileEarlyStoppingStrategy, +) from ax.exceptions.core import UnsupportedError, UserInputError from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStrategy @@ -55,9 +58,9 @@ class Client: - _experiment: Experiment | None = None - _generation_strategy: GenerationStrategy | None = None - _early_stopping_strategy: BaseEarlyStoppingStrategy | None = None + _maybe_experiment: Experiment | None = None + _maybe_generation_strategy: GenerationStrategy | None = None + _maybe_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None def __init__( self, @@ -89,13 +92,13 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None: Saves to database on completion if db_config is present. """ - if self._experiment is not None: + if self._maybe_experiment is not None: raise UnsupportedError( "Experiment already configured. Please create a new Client if you " "would like a new experiment." ) - self._experiment = experiment_from_config(config=experiment_config) + self._maybe_experiment = experiment_from_config(config=experiment_config) if self._db_config is not None: # TODO[mpolson64] Save to database @@ -137,11 +140,9 @@ def configure_optimization( Saves to database on completion if db_config is present. """ - self._none_throws_experiment().optimization_config = ( - optimization_config_from_string( - objective_str=objective, - outcome_constraint_strs=outcome_constraints, - ) + self._experiment.optimization_config = optimization_config_from_string( + objective_str=objective, + outcome_constraint_strs=outcome_constraints, ) if self._db_config is not None: @@ -159,8 +160,8 @@ def configure_generation_strategy( """ generation_strategy = choose_generation_strategy( - search_space=self._none_throws_experiment().search_space, - optimization_config=self._none_throws_experiment().optimization_config, + search_space=self._experiment.search_space, + optimization_config=self._experiment.optimization_config, num_trials=generation_strategy_config.num_trials, num_initialization_trials=( generation_strategy_config.num_initialization_trials @@ -170,9 +171,9 @@ def configure_generation_strategy( ) # Necessary for storage implications, may be removed in the future - generation_strategy._experiment = self._none_throws_experiment() + generation_strategy._experiment = self._experiment - self._generation_strategy = generation_strategy + self._maybe_generation_strategy = generation_strategy if self._db_config is not None: # TODO[mpolson64] Save to database @@ -207,7 +208,7 @@ def set_experiment(self, experiment: Experiment) -> None: Saves to database on completion if db_config is present. """ - self._experiment = experiment + self._maybe_experiment = experiment if self._db_config is not None: # TODO[mpolson64] Save to database @@ -223,7 +224,7 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> No Saves to database on completion if db_config is present. """ - self._none_throws_experiment().optimization_config = optimization_config + self._experiment.optimization_config = optimization_config if self._db_config is not None: # TODO[mpolson64] Save to database @@ -239,14 +240,9 @@ def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> No Saves to database on completion if db_config is present. """ - self._generation_strategy = generation_strategy - none_throws( - self._generation_strategy - )._experiment = self._none_throws_experiment() + self._maybe_generation_strategy = generation_strategy - none_throws( - self._generation_strategy - )._experiment = self._none_throws_experiment() + none_throws(self._maybe_generation_strategy)._experiment = self._experiment if self._db_config is not None: # TODO[mpolson64] Save to database @@ -265,7 +261,7 @@ def set_early_stopping_strategy( Saves to database on completion if db_config is present. """ - self._early_stopping_strategy = early_stopping_strategy + self._maybe_early_stopping_strategy = early_stopping_strategy if self._db_config is not None: # TODO[mpolson64] Save to database @@ -281,7 +277,7 @@ def _set_runner(self, runner: Runner) -> None: Saves to database on completion if db_config is present. """ - self._none_throws_experiment().runner = runner + self._experiment.runner = runner if self._db_config is not None: # TODO[mpolson64] Save to database @@ -328,7 +324,7 @@ def get_next_trials( A mapping of trial index to parameterization. """ - if self._none_throws_experiment().optimization_config is None: + if self._experiment.optimization_config is None: raise UnsupportedError( "OptimizationConfig not set. Please call configure_optimization before " "generating trials." @@ -340,10 +336,10 @@ def get_next_trials( # This will be changed to use gen directly post gen-unfication cc @mgarrard generator_runs = gs.gen_for_multiple_trials_with_multiple_models( - experiment=self._none_throws_experiment(), + experiment=self._experiment, pending_observations=( get_pending_observation_features_based_on_trial_status( - experiment=self._none_throws_experiment() + experiment=self._experiment ) ), n=1, @@ -359,7 +355,7 @@ def get_next_trials( for generator_run in generator_runs: trial = assert_is_instance( - self._none_throws_experiment().new_trial( + self._experiment.new_trial( generator_run=generator_run[0], ), Trial, @@ -398,20 +394,18 @@ def complete_trial( trial_index=trial_index, raw_data=raw_data, progression=progression ) - experiment = self._none_throws_experiment() - # If no OptimizationConfig is set, mark the trial as COMPLETED - if (optimization_config := experiment.optimization_config) is None: - experiment.trials[trial_index].mark_completed() + if (optimization_config := self._experiment.optimization_config) is None: + self._experiment.trials[trial_index].mark_completed() else: - trial_data = experiment.lookup_data(trial_indices=[trial_index]) + trial_data = self._experiment.lookup_data(trial_indices=[trial_index]) missing_metrics = {*optimization_config.metrics.keys()} - { *trial_data.metric_names } # If all necessary metrics are present mark the trial as COMPLETED if len(missing_metrics) == 0: - experiment.trials[trial_index].mark_completed() + self._experiment.trials[trial_index].mark_completed() # If any metrics are missing mark the trial as FAILED else: @@ -425,7 +419,7 @@ def complete_trial( # TODO[mpolson64] Save trial ... - return experiment.trials[trial_index].status + return self._experiment.trials[trial_index].status def attach_data( self, @@ -448,9 +442,7 @@ def attach_data( ({"step": progression if progression is not None else np.nan}, raw_data) ] - trial = assert_is_instance( - self._none_throws_experiment().trials[trial_index], Trial - ) + trial = assert_is_instance(self._experiment.trials[trial_index], Trial) trial.update_trial_data( # pyre-fixme[6]: Type narrowing broken because core Ax TParameterization # is dict not Mapping @@ -476,7 +468,7 @@ def attach_trial( Returns: The index of the attached trial. """ - _, trial_index = self._none_throws_experiment().attach_trial( + _, trial_index = self._experiment.attach_trial( # pyre-fixme[6]: Type narrowing broken because core Ax TParameterization # is dict not Mapping parameterizations=[parameters], @@ -508,8 +500,8 @@ def attach_baseline( arm_name=arm_name or "baseline", ) - self._none_throws_experiment().status_quo = assert_is_instance( - self._none_throws_experiment().trials[trial_index], Trial + self._experiment.status_quo = assert_is_instance( + self._experiment.trials[trial_index], Trial ).arm if self._db_config is not None: @@ -528,18 +520,11 @@ def should_stop_trial_early(self, trial_index: int) -> bool: Returns: Whether the trial should be stopped early. """ - if self._early_stopping_strategy is None: - # In the future we may want to support inferring a default early stopping - # strategy - raise UnsupportedError( - "Early stopping strategy not set. Please set an early stopping " - "strategy before calling should_stop_trial_early." - ) es_response = none_throws( - self._early_stopping_strategy + self._early_stopping_strategy_or_choose() ).should_stop_trials_early( - trial_indices={trial_index}, experiment=self._none_throws_experiment() + trial_indices={trial_index}, experiment=self._experiment ) # TODO[mpolson64]: log the returned reason for stopping the trial @@ -553,7 +538,7 @@ def mark_trial_failed(self, trial_index: int) -> None: Saves to database on completion if db_config is present. """ - self._none_throws_experiment().trials[trial_index].mark_failed() + self._experiment.trials[trial_index].mark_failed() if self._db_config is not None: # TODO[mpolson64] Save to database @@ -567,7 +552,7 @@ def mark_trial_abandoned(self, trial_index: int) -> None: Saves to database on completion if db_config is present. """ - self._none_throws_experiment().trials[trial_index].mark_abandoned() + self._experiment.trials[trial_index].mark_abandoned() if self._db_config is not None: # TODO[mpolson64] Save to database @@ -590,7 +575,7 @@ def mark_trial_early_stopped( trial_index=trial_index, raw_data=raw_data, progression=progression ) - self._none_throws_experiment().trials[trial_index].mark_early_stopped() + self._experiment.trials[trial_index].mark_early_stopped() if self._db_config is not None: # TODO[mpolson64] Save to database @@ -606,8 +591,8 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: """ scheduler = Scheduler( - experiment=self._none_throws_experiment(), - generation_strategy=(self._generation_strategy_or_choose()), + experiment=self._experiment, + generation_strategy=self._generation_strategy_or_choose(), options=SchedulerOptions( max_pending_trials=options.parallelism, tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, @@ -643,15 +628,15 @@ def compute_analyses( analyses = ( analyses if analyses is not None - else choose_analyses(experiment=self._none_throws_experiment()) + else choose_analyses(experiment=self._experiment) ) # Compute Analyses one by one and accumulate Results holding either the # AnalysisCard or an Exception and some metadata results = [ analysis.compute_result( - experiment=self._none_throws_experiment(), - generation_strategy=self._generation_strategy_or_choose(), + experiment=self._experiment, + generation_strategy=self._generation_strategy, ) for analysis in analyses ] @@ -689,7 +674,7 @@ def get_best_arm( use_model_predictions=True, otherwise returns observed data. """ - if len(self._none_throws_experiment().trials) < 1: + if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_best_arm." @@ -699,20 +684,14 @@ def get_best_arm( # unwanted public methods trial_index, parameters, _ = none_throws( BestPointMixin._get_best_trial( - experiment=self._none_throws_experiment(), - # Requiring true GenerationStrategy here, ideally we will loosen this - # in the future - generation_strategy=assert_is_instance( - self._generation_strategy_or_choose(), GenerationStrategy - ), + experiment=self._experiment, + generation_strategy=self._generation_strategy, use_model_predictions=use_model_predictions, ) ) arm = none_throws( - assert_is_instance( - self._none_throws_experiment().trials[trial_index], Trial - ).arm + assert_is_instance(self._experiment.trials[trial_index], Trial).arm ) if use_model_predictions: @@ -727,11 +706,9 @@ def get_best_arm( prediction = {} else: - data_dict = ( - self._none_throws_experiment() - .lookup_data(trial_indices=[trial_index]) - .df.to_dict() - ) + data_dict = self._experiment.lookup_data( + trial_indices=[trial_index] + ).df.to_dict() prediction = { data_dict["metric_name"][i]: (data_dict["mean"][i], data_dict["sem"][i]) @@ -753,19 +730,17 @@ def get_pareto_frontier( outcome. """ - if len(self._none_throws_experiment().trials) < 1: + if len(self._experiment.trials) < 1: raise UnsupportedError( "No trials have been run yet. Please run at least one trial before " "calling get_pareto_frontier." ) frontier = BestPointMixin._get_pareto_optimal_parameters( - experiment=self._none_throws_experiment(), + experiment=self._experiment, # Requiring true GenerationStrategy here, ideally we will loosen this # in the future - generation_strategy=assert_is_instance( - self._generation_strategy_or_choose(), GenerationStrategy - ), + generation_strategy=self._generation_strategy, use_model_predictions=use_model_predictions, ) @@ -773,9 +748,7 @@ def get_pareto_frontier( arm_names = [ none_throws( - assert_is_instance( - self._none_throws_experiment().trials[trial_index], Trial - ).arm + assert_is_instance(self._experiment.trials[trial_index], Trial).arm ).name for trial_index, _ in frontier_list ] @@ -796,11 +769,9 @@ def get_pareto_frontier( else: predictions = [] for trial_index in frontier.keys(): - data_dict = ( - self._none_throws_experiment() - .lookup_data(trial_indices=[trial_index]) - .df.to_dict() - ) + data_dict = self._experiment.lookup_data( + trial_indices=[trial_index] + ).df.to_dict() predictions.append( { @@ -842,9 +813,8 @@ def predict( Returns: A list of mappings from metric name to predicted mean and SEM """ - search_space = self._none_throws_experiment().search_space for parameters in points: - search_space.check_membership( + self._experiment.search_space.check_membership( # pyre-fixme[6]: Core Ax allows users to specify TParameterization # values as None but we do not allow this in the API. parameterization=parameters, @@ -852,10 +822,8 @@ def predict( check_all_parameters_present=True, ) - generation_strategy = self._generation_strategy_or_choose() - try: - mean, covariance = none_throws(generation_strategy.model).predict( + mean, covariance = none_throws(self._generation_strategy.model).predict( observation_features=[ # pyre-fixme[6]: Core Ax allows users to specify TParameterization # values as None but we do not allow this in the API. @@ -916,15 +884,39 @@ def load_from_database( """ ... - def _none_throws_experiment(self) -> Experiment: + # -------------------- Section 5: Private Methods ------------------------------- + # -------------------- Section 5.1: Getters and defaults ------------------------ + @property + def _experiment(self) -> Experiment: return none_throws( - self._experiment, + self._maybe_experiment, ( "Experiment not set. Please call configure_experiment or load an " "experiment before utilizing any other methods on the Client." ), ) + @property + def _generation_strategy(self) -> GenerationStrategy: + return none_throws( + self._maybe_generation_strategy, + ( + "GenerationStrategy not set. Please call " + "configure_generation_strategy, load a GenerationStrategy, or call " + "get_next_trials or run_trials to automatically choose a " + "GenerationStrategy before utilizing any other methods on the Client " + "which require one." + ), + ) + + @property + def _early_stopping_strategy(self) -> BaseEarlyStoppingStrategy: + return none_throws( + self._maybe_early_stopping_strategy, + "Early stopping strategy not set. Please set an early stopping strategy " + "before calling should_stop_trial_early.", + ) + def _generation_strategy_or_choose( self, ) -> GenerationStrategy: @@ -933,12 +925,33 @@ def _generation_strategy_or_choose( return it. """ - if self._generation_strategy is None: + try: + return self._generation_strategy + except AssertionError: self.configure_generation_strategy( generation_strategy_config=GenerationStrategyConfig() ) - return none_throws(self._generation_strategy) + return self._generation_strategy + + def _early_stopping_strategy_or_choose( + self, + ) -> BaseEarlyStoppingStrategy: + """ + If an EarlyStoppingStrategy is not set choose a default one and return it. + """ + + try: + return self._early_stopping_strategy + except AssertionError: + # PercetinleEarlyStoppingStrategy may or may not have sensible defaults at + # current moment -- we will need to be critical of these settings during + # benchmarking + self.set_early_stopping_strategy( + early_stopping_strategy=PercentileEarlyStoppingStrategy() + ) + + return self._early_stopping_strategy def _overwrite_metric(self, metric: Metric) -> None: """ @@ -951,9 +964,7 @@ def _overwrite_metric(self, metric: Metric) -> None: """ # Check the OptimizationConfig first - if ( - optimization_config := self._none_throws_experiment().optimization_config - ) is not None: + if (optimization_config := self._experiment.optimization_config) is not None: # Check the objective if isinstance( multi_objective := optimization_config.objective, MultiObjective @@ -989,14 +1000,13 @@ def _overwrite_metric(self, metric: Metric) -> None: return # Check the tracking metrics - tracking_metric_names = self._none_throws_experiment()._tracking_metrics.keys() - if metric.name in tracking_metric_names: - self._none_throws_experiment()._tracking_metrics[metric.name] = metric + if metric.name in self._experiment._tracking_metrics.keys(): + self._experiment._tracking_metrics[metric.name] = metric return # If an equivalently named Metric does not exist, add it as a tracking # metric. - self._none_throws_experiment().add_tracking_metric(metric=metric) + self._experiment.add_tracking_metric(metric=metric) logger.warning( f"Metric {metric} not found in optimization config, added as tracking " "metric." diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 6df01eb8446..804bde92baa 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -158,7 +158,7 @@ def test_configure_optimization(self) -> None: ) self.assertEqual( - none_throws(client._experiment).optimization_config, + client._experiment.optimization_config, OptimizationConfig( objective=Objective(metric=Metric(name="ne"), minimize=True), outcome_constraints=[ @@ -190,7 +190,7 @@ def test_configure_runner(self) -> None: client.set_experiment(experiment=get_branin_experiment()) client.configure_runner(runner=runner) - self.assertEqual(none_throws(client._experiment).runner, runner) + self.assertEqual(client._experiment.runner, runner) def test_configure_metric(self) -> None: client = Client() @@ -216,9 +216,7 @@ def test_configure_metric(self) -> None: self.assertEqual( custom_metric, - none_throws( - none_throws(client._experiment).optimization_config - ).objective.metric, + none_throws(client._experiment.optimization_config).objective.metric, ) # Test replacing a multi-objective @@ -228,9 +226,7 @@ def test_configure_metric(self) -> None: self.assertIn( custom_metric, assert_is_instance( - none_throws( - none_throws(client._experiment).optimization_config - ).objective, + none_throws(client._experiment.optimization_config).objective, MultiObjective, ).metrics, ) @@ -241,9 +237,7 @@ def test_configure_metric(self) -> None: self.assertIn( custom_metric, assert_is_instance( - none_throws( - none_throws(client._experiment).optimization_config - ).objective, + none_throws(client._experiment.optimization_config).objective, ScalarizedObjective, ).metrics, ) @@ -256,7 +250,7 @@ def test_configure_metric(self) -> None: self.assertEqual( custom_metric, - none_throws(none_throws(client._experiment).optimization_config) + none_throws(client._experiment.optimization_config) .outcome_constraints[0] .metric, ) @@ -265,12 +259,12 @@ def test_configure_metric(self) -> None: client.configure_optimization( objective="foo", ) - none_throws(client._experiment).add_tracking_metric(metric=Metric("custom")) + client._experiment.add_tracking_metric(metric=Metric("custom")) client.configure_metrics(metrics=[custom_metric]) self.assertEqual( custom_metric, - none_throws(client._experiment).tracking_metrics[0], + client._experiment.tracking_metrics[0], ) # Test adding a tracking metric @@ -289,7 +283,7 @@ def test_configure_metric(self) -> None: self.assertEqual( custom_metric, - none_throws(client._experiment).tracking_metrics[0], + client._experiment.tracking_metrics[0], ) def test_set_experiment(self) -> None: @@ -312,9 +306,7 @@ def test_set_optimization_config(self) -> None: optimization_config=optimization_config, ) - self.assertEqual( - none_throws(client._experiment).optimization_config, optimization_config - ) + self.assertEqual(client._experiment.optimization_config, optimization_config) def test_set_generation_strategy(self) -> None: client = Client() @@ -413,14 +405,12 @@ def test_attach_data(self) -> None: client.attach_data(trial_index=trial_index, raw_data={"foo": 1.0}) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.RUNNING, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -440,14 +430,12 @@ def test_attach_data(self) -> None: client.attach_data(trial_index=0, raw_data={"foo": 2.0}, progression=10) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.RUNNING, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -469,14 +457,12 @@ def test_attach_data(self) -> None: raw_data={"foo": 1.0, "bar": 2.0}, ) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.RUNNING, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -517,14 +503,12 @@ def test_complete_trial(self) -> None: ) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.COMPLETED, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -547,15 +531,13 @@ def test_complete_trial(self) -> None: ) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.COMPLETED, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -576,14 +558,12 @@ def test_complete_trial(self) -> None: client.complete_trial(trial_index=trial_index, raw_data={"foo": 1.0}) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.FAILED, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -618,9 +598,7 @@ def test_attach_trial(self) -> None: client.configure_optimization(objective="foo") trial_index = client.attach_trial(parameters={"x1": 0.5}, arm_name="bar") - trial = assert_is_instance( - none_throws(client._experiment).trials[trial_index], Trial - ) + trial = assert_is_instance(client._experiment.trials[trial_index], Trial) self.assertEqual(none_throws(trial.arm).parameters, {"x1": 0.5}) self.assertEqual(none_throws(trial.arm).name, "bar") self.assertEqual(trial.status, TrialStatus.RUNNING) @@ -644,14 +622,12 @@ def test_attach_baseline(self) -> None: client.configure_optimization(objective="foo") trial_index = client.attach_baseline(parameters={"x1": 0.5}) - trial = assert_is_instance( - none_throws(client._experiment).trials[trial_index], Trial - ) + trial = assert_is_instance(client._experiment.trials[trial_index], Trial) self.assertEqual(none_throws(trial.arm).parameters, {"x1": 0.5}) self.assertEqual(none_throws(trial.arm).name, "baseline") self.assertEqual(trial.status, TrialStatus.RUNNING) - self.assertEqual(client._none_throws_experiment().status_quo, trial.arm) + self.assertEqual(client._experiment.status_quo, trial.arm) def test_mark_trial_failed(self) -> None: client = Client() @@ -671,7 +647,7 @@ def test_mark_trial_failed(self) -> None: trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] client.mark_trial_failed(trial_index=trial_index) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.FAILED, ) @@ -693,7 +669,7 @@ def test_mark_trial_abandoned(self) -> None: trial_index = [*client.get_next_trials(maximum_trials=1).keys()][0] client.mark_trial_abandoned(trial_index=trial_index) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.ABANDONED, ) @@ -717,14 +693,12 @@ def test_mark_trial_early_stopped(self) -> None: trial_index=trial_index, raw_data={"foo": 0.0}, progression=1 ) self.assertEqual( - none_throws(client._experiment).trials[trial_index].status, + client._experiment.trials[trial_index].status, TrialStatus.EARLY_STOPPED, ) self.assertTrue( assert_is_instance( - none_throws(client._experiment).lookup_data( - trial_indices=[trial_index] - ), + client._experiment.lookup_data(trial_indices=[trial_index]), MapData, ).map_df.equals( pd.DataFrame( @@ -755,11 +729,6 @@ def test_should_stop_trial_early(self) -> None: ) client.configure_optimization(objective="foo") - with self.assertRaisesRegex( - UnsupportedError, "Early stopping strategy not set" - ): - client.should_stop_trial_early(trial_index=0) - client.set_early_stopping_strategy( early_stopping_strategy=PercentileEarlyStoppingStrategy( metric_names=["foo"] @@ -788,20 +757,18 @@ def test_run_trials(self) -> None: client.run_trials(maximum_trials=4, options=OrchestrationConfig()) - self.assertEqual(len(client._none_throws_experiment().trials), 4) + self.assertEqual(len(client._experiment.trials), 4) self.assertEqual( [ trial.index - for trial in client._none_throws_experiment().trials_by_status[ - TrialStatus.COMPLETED - ] + for trial in client._experiment.trials_by_status[TrialStatus.COMPLETED] ], [0, 1, 2, 3], ) self.assertTrue( assert_is_instance( - client._none_throws_experiment().lookup_data(), + client._experiment.lookup_data(), MapData, ).map_df.equals( pd.DataFrame( @@ -843,13 +810,11 @@ def test_get_next_trials_then_run_trials(self) -> None: _ = client.get_next_trials(maximum_trials=1) self.assertEqual( - len( - client._none_throws_experiment().trials_by_status[TrialStatus.COMPLETED] - ), + len(client._experiment.trials_by_status[TrialStatus.COMPLETED]), 2, ) self.assertEqual( - len(client._none_throws_experiment().trials_by_status[TrialStatus.RUNNING]), + len(client._experiment.trials_by_status[TrialStatus.RUNNING]), 1, ) @@ -860,9 +825,7 @@ def test_get_next_trials_then_run_trials(self) -> None: # All trials should be COMPLETED self.assertEqual( - len( - client._none_throws_experiment().trials_by_status[TrialStatus.COMPLETED] - ), + len(client._experiment.trials_by_status[TrialStatus.COMPLETED]), 5, ) @@ -880,6 +843,9 @@ def test_compute_analyses(self) -> None: ) ) client.configure_optimization(objective="foo") + client.configure_generation_strategy( + generation_strategy_config=GenerationStrategyConfig() + ) with self.assertLogs(logger="ax.analysis", level="ERROR") as lg: cards = client.compute_analyses(analyses=[ParallelCoordinatesPlot()]) @@ -950,11 +916,11 @@ def test_get_best_arm(self) -> None: name, [ none_throws(assert_is_instance(trial, Trial).arm).name - for trial in client._none_throws_experiment().trials.values() + for trial in client._experiment.trials.values() ], ) self.assertTrue( - client._none_throws_experiment().search_space.check_membership( + client._experiment.search_space.check_membership( parameterization=parameters # pyre-ignore[6] ) ) @@ -973,11 +939,11 @@ def test_get_best_arm(self) -> None: name, [ none_throws(assert_is_instance(trial, Trial).arm).name - for trial in client._none_throws_experiment().trials.values() + for trial in client._experiment.trials.values() ], ) self.assertTrue( - client._none_throws_experiment().search_space.check_membership( + client._experiment.search_space.check_membership( parameterization=parameters # pyre-ignore[6] ) ) @@ -1026,11 +992,11 @@ def test_get_pareto_frontier(self) -> None: name, [ none_throws(assert_is_instance(trial, Trial).arm).name - for trial in client._none_throws_experiment().trials.values() + for trial in client._experiment.trials.values() ], ) self.assertTrue( - client._none_throws_experiment().search_space.check_membership( + client._experiment.search_space.check_membership( parameterization=parameters # pyre-ignore[6] ) ) @@ -1054,11 +1020,11 @@ def test_get_pareto_frontier(self) -> None: name, [ none_throws(assert_is_instance(trial, Trial).arm).name - for trial in client._none_throws_experiment().trials.values() + for trial in client._experiment.trials.values() ], ) self.assertTrue( - client._none_throws_experiment().search_space.check_membership( + client._experiment.search_space.check_membership( parameterization=parameters # pyre-ignore[6] ) )