From 96d60c74938505af0282eaf3439a4cd61e1df2c3 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 18 Dec 2024 12:12:51 -0800 Subject: [PATCH] Have `gen` call into `gen_with_multiple_nodes` instead of `gen_multiple` (#3187) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3187 We have 4 `gen`-s in the gs file right now, this makes everyone sad. We want to be happy for the holidays. This diff is bigger than I like diffs to be sorry about that. Reviewed By: lena-kashtelyan Differential Revision: D67260951 fbshipit-source-id: 9cddb62b59f0347b47bf08177b91ffbafc25aded --- .../plotly/tests/test_insample_effects.py | 22 ++-- .../plotly/tests/test_predicted_effects.py | 43 +++++--- ax/modelbridge/generation_strategy.py | 35 ++++-- .../tests/test_generation_strategy.py | 104 ++++++++++-------- .../tests/test_transition_criterion.py | 19 ++-- ax/service/tests/test_ax_client.py | 8 +- ax/storage/sqa_store/tests/test_sqa_store.py | 18 +-- 7 files changed, 147 insertions(+), 102 deletions(-) diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py index ea944b5c5fd..ce2245c327f 100644 --- a/ax/analysis/plotly/tests/test_insample_effects.py +++ b/ax/analysis/plotly/tests/test_insample_effects.py @@ -45,14 +45,14 @@ def test_compute_uses_gs_model_if_possible(self) -> None: experiment = get_branin_experiment(with_status_quo=True) generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).set_status_quo_with_weight( status_quo=experiment.status_quo, weight=1.0 ).mark_completed(unsafe=True) experiment.fetch_data() - generation_strategy.gen_with_multiple_nodes(experiment=experiment, n=10) + generation_strategy._gen_with_multiple_nodes(experiment=experiment, n=10) # Ensure the current model is Botorch self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch") # WHEN we compute the analysis @@ -91,7 +91,7 @@ def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model( generation_strategy = self.generation_strategy generation_strategy.experiment = experiment experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).mark_completed(unsafe=True) @@ -141,7 +141,7 @@ def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None: generation_strategy = self.generation_strategy generation_strategy.experiment = experiment experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).mark_completed(unsafe=True) @@ -187,7 +187,7 @@ def test_compute_unmodeled_uses_thompson(self) -> None: generation_strategy = self.generation_strategy generation_strategy.experiment = experiment experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).mark_completed(unsafe=True) @@ -246,7 +246,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_without_a_model( generation_strategy = self.generation_strategy generation_strategy.experiment = experiment experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).mark_completed(unsafe=True) @@ -273,7 +273,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_with_a_model( generation_strategy = self.generation_strategy generation_strategy.experiment = experiment experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).set_status_quo_with_weight( @@ -282,7 +282,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_with_a_model( experiment.fetch_data() # AND GIVEN the experiment has a trial with no data empty_trial = experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ), ) @@ -320,7 +320,7 @@ def test_constraints(self) -> None: generation_strategy = self.generation_strategy generation_strategy.experiment = experiment trial = experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ), ) @@ -328,7 +328,7 @@ def test_constraints(self) -> None: trial.mark_completed(unsafe=True) experiment.fetch_data() trial = experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ), ) @@ -400,7 +400,7 @@ def test_level(self) -> None: generation_strategy = self.generation_strategy generation_strategy.experiment = experiment trial = experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ), ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index f89c50f5328..5001f68b1c6 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -84,7 +84,7 @@ def test_compute(self) -> None: experiment.add_tracking_metric(get_branin_metric(name="tracking_branin")) generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).set_status_quo_with_weight( @@ -92,7 +92,7 @@ def test_compute(self) -> None: ).mark_completed(unsafe=True) experiment.fetch_data() experiment.new_batch_trial( - generator_runs=generation_strategy.gen_with_multiple_nodes( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10 ) ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) @@ -153,27 +153,31 @@ def test_compute_multitask(self) -> None: experiment = get_branin_experiment(with_status_quo=True) generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight( status_quo=experiment.status_quo, weight=1 ).mark_completed(unsafe=True) experiment.fetch_data() experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight( status_quo=experiment.status_quo, weight=1 ).mark_completed(unsafe=True) experiment.fetch_data() # leave as a candidate experiment.new_batch_trial( - generator_run=generation_strategy.gen( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), ) ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1) experiment.new_batch_trial( - generator_run=generation_strategy.gen( + generator_runs=generation_strategy._gen_with_multiple_nodes( experiment=experiment, n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), @@ -224,15 +228,21 @@ def test_it_does_not_plot_abandoned_trials(self) -> None: experiment = get_branin_experiment() generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).mark_completed(unsafe=True) experiment.fetch_data() # candidate trial experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ) experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).mark_abandoned() arms_with_data = set(experiment.lookup_data().df["arm_name"].unique()) # WHEN we compute the analysis @@ -268,9 +278,10 @@ def test_it_works_for_non_batch_experiments(self) -> None: last_model_key = sobol_key while last_model_key == sobol_key: trial = experiment.new_trial( - generator_run=generation_strategy.gen( - experiment=experiment, n=1, pending_observation=True - ) + generator_run=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, + n=1, + )[0] ) last_model_key = none_throws(trial.generator_run)._model_key if last_model_key == sobol_key: @@ -301,13 +312,17 @@ def test_constraints(self) -> None: ] generation_strategy = self.generation_strategy trial = experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10), + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ), ) trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) trial.mark_completed(unsafe=True) experiment.fetch_data() trial = experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10), + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment, n=10 + ), ) trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) # WHEN we compute the analysis and constraints are violated diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 4634a8ca65a..96716891cdc 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -341,9 +341,9 @@ def gen( self, experiment: Experiment, data: Data | None = None, - n: int = 1, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - **kwargs: Any, + n: int = 1, + fixed_features: ObservationFeatures | None = None, ) -> GeneratorRun: """Produce the next points in the experiment. Additional kwargs passed to this method are propagated directly to the underlying model's `gen`, along @@ -372,16 +372,23 @@ def gen( observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. """ - return self._gen_multiple( + gr = self._gen_with_multiple_nodes( experiment=experiment, - num_generator_runs=1, data=data, n=n, pending_observations=pending_observations, - **kwargs, - )[0] + fixed_features=fixed_features, + ) + if len(gr) > 1: + raise UnsupportedError( + "By calling into GenerationStrategy.gen(), you are should be " + "expecting a single `Trial` with only one `GeneratorRun`. However, " + "the underlying GenerationStrategy produced multiple `GeneratorRuns` " + f"and returned the following list of `GeneratorRun`s: {gr}" + ) + return gr[0] - def gen_with_multiple_nodes( + def _gen_with_multiple_nodes( self, experiment: Experiment, data: Data | None = None, @@ -392,9 +399,8 @@ def gen_with_multiple_nodes( ) -> list[GeneratorRun]: """Produces a List of GeneratorRuns for a single trial, either ``Trial`` or ``BatchTrial``, and if producing a ``BatchTrial`` allows for multiple - models to be used to generate GeneratorRuns for that trial. + models to be used to generate ``GeneratorRun``s for that trial. - NOTE: This method is in development. Please do not use it yet. Args: experiment: Experiment, for which the generation strategy is producing @@ -433,6 +439,11 @@ def gen_with_multiple_nodes( pending_observations = deepcopy(pending_observations) or {} self.experiment = experiment self._validate_arms_per_node(arms_per_node=arms_per_node) + if self.optimization_complete: + raise GenerationStrategyCompleted( + f"Generation strategy {self} generated all the trials as " + "specified in its nodes." + ) # TODO: @mgarrard update this when gen methods are merged gen_kwargs: dict[str, Any] = {} gen_kwargs = { @@ -562,7 +573,7 @@ def gen_for_multiple_trials_with_multiple_models( num_trials = max(min(num_trials, gr_limit), 1) for _i in range(num_trials): trial_grs.append( - self.gen_with_multiple_nodes( + self._gen_with_multiple_nodes( experiment=experiment, data=data, n=n, @@ -671,7 +682,7 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationStep]) -> None: # Set transition_to field for all but the last step, which remains # null. - if idx != len(self._steps): + if idx < len(self._steps): for transition_criteria in step.transition_criteria: if ( transition_criteria.criterion_class @@ -929,7 +940,7 @@ def _should_continue_gen_for_trial(self) -> bool: """Determine if we should continue generating for the current trial, or end generation for the current trial. Note that generating more would involve transitioning to a next node, because each node generates once per call to - ``GenerationStrategy.gen_with_multiple_nodes``. + ``GenerationStrategy._gen_with_multiple_nodes``. Returns: A boolean which represents if generation for a trial is complete diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index f9ee3427c0d..2a7b6f15a62 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -253,21 +253,23 @@ def setUp(self) -> None: name="Sobol+MBM_Nodes", nodes=[self.sobol_node, self.mbm_node], ) - self.mbm_to_sobol2_max = MinTrials( + self.mbm_to_sobol2_with_running_trial = MinTrials( threshold=1, transition_to="sobol_2", block_transition_if_unmet=True, only_in_statuses=[TrialStatus.RUNNING], use_all_trials_in_exp=True, ) - self.mbm_to_sobol2_min = MinTrials( + self.mbm_to_sobol2_with_completed_trial = MinTrials( threshold=1, transition_to="sobol_2", block_transition_if_unmet=True, only_in_statuses=[TrialStatus.COMPLETED], use_all_trials_in_exp=True, ) - self.mbm_to_sobol_auto = AutoTransitionAfterGen(transition_to="sobol_3") + self.mbm_to_sobol_auto = AutoTransitionAfterGen( + transition_to="sobol_3", continue_trial_generation=False + ) self.competing_tc_gs = GenerationStrategy( nodes=[ GenerationNode( @@ -279,8 +281,8 @@ def setUp(self) -> None: node_name="mbm", model_specs=[self.mbm_model_spec], transition_criteria=[ - self.mbm_to_sobol2_max, - self.mbm_to_sobol2_min, + self.mbm_to_sobol2_with_running_trial, + self.mbm_to_sobol2_with_completed_trial, self.mbm_to_sobol_auto, ], ), @@ -362,6 +364,7 @@ def _get_sobol_mbm_step_gs( model=Models.BOTORCH_MODULAR, num_trials=num_mbm_trials, model_kwargs=self.step_model_kwargs, + enforce_num_trials=True, ), ], ) @@ -417,7 +420,7 @@ def test_validation(self) -> None: factorial_thompson_generation_strategy.uses_non_registered_models ) with self.assertRaises(ValueError): - factorial_thompson_generation_strategy.gen(exp) + factorial_thompson_generation_strategy._gen_with_multiple_nodes(exp) self.assertEqual(GenerationStep(model=sum, num_trials=1).model_name, "sum") with self.assertRaisesRegex(UserInputError, "Maximum parallelism should be"): GenerationStrategy( @@ -586,7 +589,7 @@ def test_sobol_MBM_strategy(self) -> None: self.assertEqual(ms, {"init_position": i + 1, "seed": expected_seed}) # Check completeness error message when GS should be done. with self.assertRaises(GenerationStrategyCompleted): - g = gs.gen(exp) + gs.gen(exp) def test_sobol_MBM_strategy_keep_generating(self) -> None: exp = get_branin_experiment() @@ -618,7 +621,7 @@ def test_sobol_strategy(self) -> None: @patch(f"{Experiment.__module__}.Experiment.fetch_data", return_value=get_data()) def test_factorial_thompson_strategy(self, _: MagicMock) -> None: exp = get_branin_experiment() - factorial_thompson_generation_strategy = GenerationStrategy( + factorial_thompson_gs = GenerationStrategy( steps=[ GenerationStep( model=Models.FACTORIAL, @@ -632,18 +635,24 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None: ), ] ) - self.assertEqual( - factorial_thompson_generation_strategy.name, "Factorial+Thompson" - ) + self.assertEqual(factorial_thompson_gs.name, "Factorial+Thompson") mock_model_bridge = self.mock_discrete_model_bridge.return_value # Initial factorial batch. - exp.new_batch_trial(factorial_thompson_generation_strategy.gen(experiment=exp)) + exp.new_batch_trial( + generator_runs=factorial_thompson_gs._gen_with_multiple_nodes( + experiment=exp + ) + ) args, kwargs = mock_model_bridge._set_kwargs_to_save.call_args self.assertEqual(kwargs.get("model_key"), "Factorial") # Subsequent Thompson sampling batch. - exp.new_batch_trial(factorial_thompson_generation_strategy.gen(experiment=exp)) + exp.new_batch_trial( + generator_runs=factorial_thompson_gs._gen_with_multiple_nodes( + experiment=exp + ) + ) args, kwargs = mock_model_bridge._set_kwargs_to_save.call_args self.assertEqual(kwargs.get("model_key"), "Thompson") @@ -683,16 +692,18 @@ def test_sobol_MBM_strategy_batches(self) -> None: sobol_MBM_generation_strategy = self._get_sobol_mbm_step_gs( num_sobol_trials=1, num_mbm_trials=6 ) - gr = sobol_MBM_generation_strategy.gen(exp, n=2) - exp.new_batch_trial(generator_run=gr).run() + grs = sobol_MBM_generation_strategy._gen_with_multiple_nodes(exp, n=2) + exp.new_batch_trial(generator_runs=grs).run() for i in range(1, 8): if i == 7: # Check completeness error message. with self.assertRaises(GenerationStrategyCompleted): - g = sobol_MBM_generation_strategy.gen(exp, n=2) + grs_2 = sobol_MBM_generation_strategy._gen_with_multiple_nodes( + exp, n=2 + ) else: - g = sobol_MBM_generation_strategy.gen(exp, n=2) - exp.new_batch_trial(generator_run=g).run() + grs_2 = sobol_MBM_generation_strategy._gen_with_multiple_nodes(exp, n=2) + exp.new_batch_trial(generator_runs=grs_2).run() self.assertIsInstance(sobol_MBM_generation_strategy.model, TorchModelBridge) def test_with_factory_function(self) -> None: @@ -1318,14 +1329,14 @@ def test_gs_with_suggested_n_is_zero(self) -> None: # in a cyclic gs dag for _i in range(3): # if you request < 6 arms, repeat arm input constructor will return 0 arms - grs = gs.gen_with_multiple_nodes(experiment=exp, n=5) + grs = gs._gen_with_multiple_nodes(experiment=exp, n=5) self.assertEqual(len(grs), 1) # only generated from one node self.assertEqual(grs[0]._generation_node_name, "sobol_3") self.assertEqual(len(grs[0].arms), 5) # all 5 arms from sobol 3 self.assertTrue(node_2._should_skip) # Now validate that we can get grs from sobol_2 if we request enough n - grs = gs.gen_with_multiple_nodes(experiment=exp, n=8) + grs = gs._gen_with_multiple_nodes(experiment=exp, n=8) self.assertEqual(len(grs), 2) self.assertEqual(grs[0]._generation_node_name, "sobol_2") self.assertEqual(len(grs[0].arms), 1) @@ -1377,7 +1388,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: original_method=ModelSpec.gen, ) as model_spec_gen_mock: # Generate a trial that should be composed of arms from 3 nodes - grs = gs.gen_with_multiple_nodes( + grs = gs._gen_with_multiple_nodes( experiment=exp, arms_per_node=arms_per_node ) @@ -1422,7 +1433,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: ) # check that we can pass in pending points - grs = gs.gen_with_multiple_nodes( + grs = gs._gen_with_multiple_nodes( experiment=exp, arms_per_node=arms_per_node, pending_observations=original_pending, @@ -1620,7 +1631,9 @@ def test_gs_with_competing_transition_edges(self) -> None: def test_transition_edges(self) -> None: """Test transition_edges property of ``GenerationNode``""" - mbm_to_sobol_auto = AutoTransitionAfterGen(transition_to="sobol") + mbm_to_sobol_auto = AutoTransitionAfterGen( + transition_to="sobol", continue_trial_generation=False + ) gs = GenerationStrategy( nodes=[ GenerationNode( @@ -1632,8 +1645,8 @@ def test_transition_edges(self) -> None: node_name="mbm", model_specs=[self.mbm_model_spec], transition_criteria=[ - self.mbm_to_sobol2_max, - self.mbm_to_sobol2_min, + self.mbm_to_sobol2_with_running_trial, + self.mbm_to_sobol2_with_completed_trial, mbm_to_sobol_auto, ], ), @@ -1653,7 +1666,10 @@ def test_transition_edges(self) -> None: self.assertEqual( gs._curr.transition_edges, { - "sobol_2": [self.mbm_to_sobol2_max, self.mbm_to_sobol2_min], + "sobol_2": [ + self.mbm_to_sobol2_with_running_trial, + self.mbm_to_sobol2_with_completed_trial, + ], "sobol": [mbm_to_sobol_auto], }, ) @@ -1673,7 +1689,7 @@ def test_multiple_arms_per_node(self) -> None: "sobol_4": 4, } with self.assertRaisesRegex(UserInputError, "defined in `arms_per_node`"): - gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) + gs._gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) # now we will check that the first trial contains 3 arms, the second trial # contains 6 arms (2 from mbm, 1 from sobol_2, 3 from sobol_3), and all @@ -1688,7 +1704,7 @@ def test_multiple_arms_per_node(self) -> None: # for the first trial, we start on sobol, we generate the trial, but it hasn't # been run yet, so we remain on sobol trial0 = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) + generator_runs=gs._gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) self.assertEqual(len(trial0.arms_by_name), 3) self.assertEqual(trial0.generator_runs[0]._generation_node_name, "sobol") @@ -1700,7 +1716,7 @@ def test_multiple_arms_per_node(self) -> None: # to the last first node in a trial. for _i in range(0, 2): trial = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes( + generator_runs=gs._gen_with_multiple_nodes( exp, arms_per_node=arms_per_node ) ) @@ -1717,7 +1733,7 @@ def test_multiple_arms_per_node(self) -> None: # after running the next trial should be made from sobol 4 trial.run() trial = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) + generator_runs=gs._gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4") self.assertEqual(len(trial.generator_runs[0].arms), 4) @@ -1731,7 +1747,7 @@ def test_gen_with_multiple_uses_total_concurrent_arms_for_a_default(self) -> Non gs = GenerationStrategy(nodes=[self.sobol_node], name="test") gs.experiment = exp exp._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value] = 3 - grs = gs.gen_with_multiple_nodes(exp) + grs = gs._gen_with_multiple_nodes(exp) self.assertEqual(len(grs), 1) self.assertEqual(len(grs[0].arms), 3) @@ -1784,14 +1800,14 @@ def test_node_gs_with_auto_transitions(self) -> None: # been run yet, so we remain on sobol, after the trial is run, the subsequent # trials should be from node mbm, sobol_2, and sobol_3 self.assertEqual(gs.current_node_name, "sobol") - trial0 = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) + trial0 = exp.new_batch_trial(generator_runs=gs._gen_with_multiple_nodes(exp)) self.assertEqual(gs.current_node_name, "sobol") # while here, test the last generator run property on node self.assertEqual(gs.current_node.node_that_generated_last_gr, "sobol") trial0.run() for _i in range(0, 2): - trial = exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp)) + trial = exp.new_batch_trial(generator_runs=gs._gen_with_multiple_nodes(exp)) self.assertEqual(gs.current_node_name, "sobol_3") self.assertEqual(len(trial.generator_runs), 3) self.assertEqual(trial.generator_runs[0]._generation_node_name, "mbm") @@ -1806,7 +1822,7 @@ def test_node_gs_with_auto_transitions_three_phase(self) -> None: # for the first trial, we start on sobol, we generate the trial, but it hasn't # been run yet, so we remain on sobol self.assertEqual(gs_2.current_node_name, "sobol") - trial0 = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) + trial0 = exp.new_batch_trial(generator_runs=gs_2._gen_with_multiple_nodes(exp)) self.assertEqual(gs_2.current_node_name, "sobol") trial0.run() @@ -1816,7 +1832,7 @@ def test_node_gs_with_auto_transitions_three_phase(self) -> None: # to the last first node in a trial. for _i in range(0, 2): trial = exp.new_batch_trial( - generator_runs=gs_2.gen_with_multiple_nodes(exp) + generator_runs=gs_2._gen_with_multiple_nodes(exp) ) self.assertEqual(gs_2.current_node_name, "sobol_3") self.assertEqual(len(trial.generator_runs), 3) @@ -1826,7 +1842,7 @@ def test_node_gs_with_auto_transitions_three_phase(self) -> None: # after running the next trial should be made from sobol 4 trial.run() - trial = exp.new_batch_trial(generator_runs=gs_2.gen_with_multiple_nodes(exp)) + trial = exp.new_batch_trial(generator_runs=gs_2._gen_with_multiple_nodes(exp)) self.assertEqual(trial.generator_runs[0]._generation_node_name, "sobol_4") def test_trials_as_df_node_gs(self) -> None: @@ -1843,7 +1859,7 @@ def test_trials_as_df_node_gs(self) -> None: self.assertIsNone(gs.trials_as_df) # Now the trial should appear in the DF. trial = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) + generator_runs=gs._gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) trials_df = none_throws(gs.trials_as_df) self.assertFalse(trials_df.empty) @@ -1857,7 +1873,7 @@ def test_trials_as_df_node_gs(self) -> None: # Add a new trial which will be generated from multiple nodes, and check that # is properly reflected in the DF. trial = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) + generator_runs=gs._gen_with_multiple_nodes(exp, arms_per_node=arms_per_node) ) self.assertEqual( none_throws(gs.trials_as_df).head()["Generation Nodes"][1], @@ -1900,7 +1916,7 @@ def test_gs_with_fixed_features_constructor(self) -> None: # The first trial is our exploration trial, all arms should be generated from # the sobol node due to the input constructor == ALL_N. trial0 = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, n=9) + generator_runs=gs._gen_with_multiple_nodes(exp, n=9) ) self.assertEqual(len(trial0.arms_by_name), 9) self.assertEqual(trial0.generator_runs[0]._generation_node_name, "sobol_node") @@ -1910,7 +1926,9 @@ def test_gs_with_fixed_features_constructor(self) -> None: mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", original_method=ModelSpec.gen, ) as model_spec_gen_mock: - exp.new_batch_trial(generator_runs=gs.gen_with_multiple_nodes(exp, n=9)) + exp.new_batch_trial( + generator_runs=gs._gen_with_multiple_nodes(exp, n=9) + ) fixed_features_in_gen = model_spec_gen_mock.call_args_list[ 0 ].kwargs.get("fixed_features") @@ -1929,7 +1947,7 @@ def test_gs_with_fixed_features_constructor(self) -> None: parameters={}, trial_index=4 ) exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes( + generator_runs=gs._gen_with_multiple_nodes( exp, n=9, fixed_features=passed_fixed_features ) ) @@ -2039,7 +2057,7 @@ def test_gs_with_input_constructor(self) -> None: # The first trial is our exploration trial, all arms should be generated from # the sobol node due to the input constructor == ALL_N. trial0 = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, n=9) + generator_runs=gs._gen_with_multiple_nodes(exp, n=9) ) self.assertEqual(len(trial0.arms_by_name), 9) self.assertEqual(trial0.generator_runs[0]._generation_node_name, "sobol_node") @@ -2049,7 +2067,7 @@ def test_gs_with_input_constructor(self) -> None: # subsequent trials should be generated from sobol_2 and sobol_3, with # sobol_2 generating 1 arm and sobol_3 generating the remaining 8 arms. trial = exp.new_batch_trial( - generator_runs=gs.gen_with_multiple_nodes(exp, n=9) + generator_runs=gs._gen_with_multiple_nodes(exp, n=9) ) self.assertEqual(gs.current_node_name, "sobol_3") self.assertEqual(len(trial.arms_by_name), 9) diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 6096b7fffe3..c3b4ba043a0 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -157,7 +157,7 @@ def test_aux_experiment_check_in_gs(self) -> None: self.assertEqual(gs.current_node_name, "sobol_1") # Do not transition because no aux experiment - grs = gs.gen_with_multiple_nodes(experiment=experiment, n=5) + grs = gs._gen_with_multiple_nodes(experiment=experiment, n=5) self.assertEqual(gs.current_node_name, "sobol_1") self.assertEqual(len(grs), 1) self.assertEqual(len(grs[0].arms), 5) @@ -166,19 +166,19 @@ def test_aux_experiment_check_in_gs(self) -> None: experiment.auxiliary_experiments_by_purpose = { TestAuxiliaryExperimentPurpose.TestAuxExpPurpose: [aux_exp], } - grs = gs.gen_with_multiple_nodes(experiment=experiment, n=5) + grs = gs._gen_with_multiple_nodes(experiment=experiment, n=5) self.assertEqual(gs.current_node_name, "sobol_2") self.assertEqual(len(grs), 1) self.assertEqual(len(grs[0].arms), 5) # Do not move even when the aux exp is still there - grs = gs.gen_with_multiple_nodes(experiment=experiment, n=5) + grs = gs._gen_with_multiple_nodes(experiment=experiment, n=5) self.assertEqual(gs.current_node_name, "sobol_2") self.assertEqual(len(grs), 1) self.assertEqual(len(grs[0].arms), 5) # Remove the aux experiment and move back to sobol_1 experiment.auxiliary_experiments_by_purpose = {} - grs = gs.gen_with_multiple_nodes(experiment=experiment, n=5) + grs = gs._gen_with_multiple_nodes(experiment=experiment, n=5) self.assertEqual(gs.current_node_name, "sobol_1") self.assertEqual(len(grs), 1) self.assertEqual(len(grs[0].arms), 5) @@ -279,8 +279,6 @@ def test_min_trials_is_met(self) -> None: # Need to add trials to test the transition criteria `is_met` method for _i in range(4): experiment.new_trial(gs.gen(experiment=experiment)) - - # TODO: @mgarrard More comprehensive test of trials_from_node node_0_trials = gs._steps[0].trials_from_node node_1_trials = gs._steps[1].trials_from_node @@ -341,8 +339,8 @@ def test_auto_transition(self) -> None: ) gs.experiment = experiment self.assertEqual(gs.current_node_name, "sobol_1") - gs.gen(experiment=experiment) - gs.gen(experiment=experiment) + gs._gen_with_multiple_nodes(experiment=experiment) + gs._gen_with_multiple_nodes(experiment=experiment) self.assertEqual(gs.current_node_name, "sobol_2") def test_auto_with_should_skip_node(self) -> None: @@ -457,7 +455,6 @@ def test_max_trials_is_met(self) -> None: curr_node=gs._steps[0], ) ) - # After adding trials, should pass for _i in range(4): experiment.new_trial(gs.gen(experiment=experiment)) @@ -469,7 +466,6 @@ def test_max_trials_is_met(self) -> None: curr_node=gs._steps[0], ) ) - # Check not in statuses and only in statuses max_criterion_not_in_statuses = MaxTrials( threshold=2, @@ -492,7 +488,6 @@ def test_max_trials_is_met(self) -> None: experiment=experiment, curr_node=gs._steps[0] ) ) - # set 3 of the 4 trials to status == completed for _idx, trial in experiment.trials.items(): trial._status = TrialStatus.COMPLETED @@ -510,7 +505,7 @@ def test_max_trials_is_met(self) -> None: ) def test_trials_from_node_empty(self) -> None: - """Tests MinTrials and MaxTrials default to experiment + """Tests MinTrials defaults to experiment level trials when trials_from_node is None. """ experiment = get_experiment() diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 4cf6825f326..c78c35bee9e 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -2999,10 +2999,12 @@ def test_gen_fixed_features(self) -> None: name="fixed_features", ) with mock.patch.object( - GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen + GenerationStrategy, + "gen", + wraps=ax_client.generation_strategy.gen, ) as mock_gen: with self.subTest("fixed_features is None"): - params, idx = ax_client.get_next_trial() + ax_client.get_next_trial() call_kwargs = mock_gen.call_args_list[0][1] ff = call_kwargs["fixed_features"] self.assertIsNone(ff) @@ -3010,7 +3012,7 @@ def test_gen_fixed_features(self) -> None: fixed_features = FixedFeatures( parameters={"x": 0.0, "y": 5.0}, trial_index=0 ) - params, idx = ax_client.get_next_trial(fixed_features=fixed_features) + ax_client.get_next_trial(fixed_features=fixed_features) call_kwargs = mock_gen.call_args_list[1][1] ff = call_kwargs["fixed_features"] self.assertEqual(ff.parameters, fixed_features.parameters) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 84187846a4f..7ee7bb93d41 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -595,8 +595,8 @@ def wrapper(*args: Any, **kwargs: Any) -> T: gs = get_generation_strategy( with_experiment=True, with_callable_model_kwarg=False ) - gs.gen(gs.experiment) - gs.gen(gs.experiment) + gs.gen(experiment=gs.experiment) + gs.gen(experiment=gs.experiment) save_experiment(gs.experiment) save_generation_strategy(gs) @@ -810,7 +810,7 @@ def test_EncodeDecode(self) -> None: def test_EncodeGeneratorRunReducedState(self) -> None: exp = get_branin_experiment() gs = get_generation_strategy(with_callable_model_kwarg=False) - gr = gs.gen(exp) + gr = gs.gen(experiment=exp) for key in [attr.key for attr in GR_LARGE_MODEL_ATTRS]: self.assertIsNotNone(getattr(gr, f"_{key}")) @@ -833,7 +833,7 @@ def test_EncodeGeneratorRunReducedState(self) -> None: def test_load_and_save_generator_run_reduced_state(self) -> None: exp = get_branin_experiment() gs = get_generation_strategy(with_callable_model_kwarg=False) - gr = gs.gen(exp) + gr = gs.gen(experiment=exp) original_gen_metadata = {"foo": "bar"} gr._gen_metadata = original_gen_metadata exp.new_trial(generator_run=gr) @@ -1612,8 +1612,12 @@ def test_EncodeDecodeGenerationNodeGSWithAdvancedSettings(self) -> None: generation_strategy = sobol_gpei_generation_node_gs( with_input_constructors_all_n=True ) - experiment.new_trial(generation_strategy.gen(experiment=experiment)) - generation_strategy.gen(experiment, data=get_branin_data()) + experiment.new_batch_trial( + generator_runs=generation_strategy._gen_with_multiple_nodes( + experiment=experiment + ) + ) + generation_strategy._gen_with_multiple_nodes(experiment, data=get_branin_data()) save_experiment(experiment) save_generation_strategy(generation_strategy=generation_strategy) @@ -2150,7 +2154,7 @@ def test_GeneratorRunValidatedFields(self) -> None: # experiment loading. exp = get_branin_experiment() gs = get_generation_strategy(with_callable_model_kwarg=False) - trial = exp.new_trial(gs.gen(exp)) + trial = exp.new_trial(gs.gen(experiment=exp)) for instrumented_attr in GR_LARGE_MODEL_ATTRS: self.assertIsNotNone( getattr(trial.generator_run, f"_{instrumented_attr.key}")