Skip to content

Commit

Permalink
Have gen call into gen_with_multiple_nodes instead of `gen_multip…
Browse files Browse the repository at this point in the history
…le` (#3187)

Summary:

We have 4 `gen`-s in the gs file right now, this makes everyone sad. We want to be happy for the holidays.

This diff is bigger than I like diffs to be sorry about that.

Reviewed By: lena-kashtelyan

Differential Revision: D67260951
  • Loading branch information
mgarrard authored and facebook-github-bot committed Dec 18, 2024
1 parent 26e2bb4 commit 72aeb6d
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 102 deletions.
22 changes: 11 additions & 11 deletions ax/analysis/plotly/tests/test_insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def test_compute_uses_gs_model_if_possible(self) -> None:
experiment = get_branin_experiment(with_status_quo=True)
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1.0
).mark_completed(unsafe=True)
experiment.fetch_data()
generation_strategy.gen_with_multiple_nodes(experiment=experiment, n=10)
generation_strategy._gen_with_multiple_nodes(experiment=experiment, n=10)
# Ensure the current model is Botorch
self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
# WHEN we compute the analysis
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model(
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_completed(unsafe=True)
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None:
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_completed(unsafe=True)
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_compute_unmodeled_uses_thompson(self) -> None:
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_completed(unsafe=True)
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_without_a_model(
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_completed(unsafe=True)
Expand All @@ -273,7 +273,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_with_a_model(
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
Expand All @@ -282,7 +282,7 @@ def test_compute_requires_data_for_the_metric_on_the_trial_with_a_model(
experiment.fetch_data()
# AND GIVEN the experiment has a trial with no data
empty_trial = experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
)
Expand Down Expand Up @@ -320,15 +320,15 @@ def test_constraints(self) -> None:
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
trial = experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
)
trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
trial.mark_completed(unsafe=True)
experiment.fetch_data()
trial = experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
)
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_level(self) -> None:
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
trial = experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
Expand Down
43 changes: 29 additions & 14 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def test_compute(self) -> None:
experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1.0
).mark_completed(unsafe=True)
experiment.fetch_data()
experiment.new_batch_trial(
generator_runs=generation_strategy.gen_with_multiple_nodes(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
Expand Down Expand Up @@ -153,27 +153,31 @@ def test_compute_multitask(self) -> None:
experiment = get_branin_experiment(with_status_quo=True)
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(unsafe=True)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(unsafe=True)
experiment.fetch_data()
# leave as a candidate
experiment.new_batch_trial(
generator_run=generation_strategy.gen(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment,
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
experiment.new_batch_trial(
generator_run=generation_strategy.gen(
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment,
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
Expand Down Expand Up @@ -224,15 +228,21 @@ def test_it_does_not_plot_abandoned_trials(self) -> None:
experiment = get_branin_experiment()
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_completed(unsafe=True)
experiment.fetch_data()
# candidate trial
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
)
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
)
).mark_abandoned()
arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
# WHEN we compute the analysis
Expand Down Expand Up @@ -268,9 +278,10 @@ def test_it_works_for_non_batch_experiments(self) -> None:
last_model_key = sobol_key
while last_model_key == sobol_key:
trial = experiment.new_trial(
generator_run=generation_strategy.gen(
experiment=experiment, n=1, pending_observation=True
)
generator_run=generation_strategy._gen_with_multiple_nodes(
experiment=experiment,
n=1,
)[0]
)
last_model_key = none_throws(trial.generator_run)._model_key
if last_model_key == sobol_key:
Expand Down Expand Up @@ -301,13 +312,17 @@ def test_constraints(self) -> None:
]
generation_strategy = self.generation_strategy
trial = experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10),
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
)
trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
trial.mark_completed(unsafe=True)
experiment.fetch_data()
trial = experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10),
generator_runs=generation_strategy._gen_with_multiple_nodes(
experiment=experiment, n=10
),
)
trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
# WHEN we compute the analysis and constraints are violated
Expand Down
35 changes: 23 additions & 12 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,9 @@ def gen(
self,
experiment: Experiment,
data: Data | None = None,
n: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
**kwargs: Any,
n: int = 1,
fixed_features: ObservationFeatures | None = None,
) -> GeneratorRun:
"""Produce the next points in the experiment. Additional kwargs passed to
this method are propagated directly to the underlying model's `gen`, along
Expand Down Expand Up @@ -372,16 +372,23 @@ def gen(
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
"""
return self._gen_multiple(
gr = self._gen_with_multiple_nodes(
experiment=experiment,
num_generator_runs=1,
data=data,
n=n,
pending_observations=pending_observations,
**kwargs,
)[0]
fixed_features=fixed_features,
)
if len(gr) > 1:
raise UnsupportedError(
"By calling into GenerationStrategy.gen(), you are should be "
"expecting a single `Trial` with only one `GeneratorRun`. However, "
"the underlying GenerationStrategy produced multiple `GeneratorRuns` "
f"and returned the following list of `GeneratorRun`s: {gr}"
)
return gr[0]

def gen_with_multiple_nodes(
def _gen_with_multiple_nodes(
self,
experiment: Experiment,
data: Data | None = None,
Expand All @@ -392,9 +399,8 @@ def gen_with_multiple_nodes(
) -> list[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
``BatchTrial``, and if producing a ``BatchTrial`` allows for multiple
models to be used to generate GeneratorRuns for that trial.
models to be used to generate ``GeneratorRun``s for that trial.
NOTE: This method is in development. Please do not use it yet.
Args:
experiment: Experiment, for which the generation strategy is producing
Expand Down Expand Up @@ -433,6 +439,11 @@ def gen_with_multiple_nodes(
pending_observations = deepcopy(pending_observations) or {}
self.experiment = experiment
self._validate_arms_per_node(arms_per_node=arms_per_node)
if self.optimization_complete:
raise GenerationStrategyCompleted(
f"Generation strategy {self} generated all the trials as "
"specified in its nodes."
)
# TODO: @mgarrard update this when gen methods are merged
gen_kwargs: dict[str, Any] = {}
gen_kwargs = {
Expand Down Expand Up @@ -562,7 +573,7 @@ def gen_for_multiple_trials_with_multiple_models(
num_trials = max(min(num_trials, gr_limit), 1)
for _i in range(num_trials):
trial_grs.append(
self.gen_with_multiple_nodes(
self._gen_with_multiple_nodes(
experiment=experiment,
data=data,
n=n,
Expand Down Expand Up @@ -671,7 +682,7 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationStep]) -> None:

# Set transition_to field for all but the last step, which remains
# null.
if idx != len(self._steps):
if idx < len(self._steps):
for transition_criteria in step.transition_criteria:
if (
transition_criteria.criterion_class
Expand Down Expand Up @@ -929,7 +940,7 @@ def _should_continue_gen_for_trial(self) -> bool:
"""Determine if we should continue generating for the current trial, or end
generation for the current trial. Note that generating more would involve
transitioning to a next node, because each node generates once per call to
``GenerationStrategy.gen_with_multiple_nodes``.
``GenerationStrategy._gen_with_multiple_nodes``.
Returns:
A boolean which represents if generation for a trial is complete
Expand Down
Loading

0 comments on commit 72aeb6d

Please sign in to comment.