Accept any GSInterface in scheduler (facebook#2033)
Summary: Pull Request resolved: facebook#2033

Reviewed By: lena-kashtelyan

Differential Revision: D51307866
Daniel Cohen authored and facebook-github-bot committed Nov 30, 2023
1 parent 5432b24 commit 60137ca
Showing 6 changed files with 69 additions and 34 deletions.
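
In practice, the change widens the Scheduler's accepted type for its generation strategy from `GenerationStrategy` to the broader `GenerationStrategyInterface`. A rough usage sketch follows; the `MyCustomStrategy` class, the `my_experiment` object, and the exact import paths are assumptions for illustration, not part of this commit:

from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.service.scheduler import Scheduler, SchedulerOptions

class MyCustomStrategy(GenerationStrategyInterface):
    # Implement the interface's abstract methods here, most importantly
    # `gen_for_multiple_trials_with_multiple_models`, which the scheduler calls.
    ...

scheduler = Scheduler(
    experiment=my_experiment,  # an existing Ax `Experiment`
    generation_strategy=MyCustomStrategy(),  # no longer required to be a `GenerationStrategy`
    options=SchedulerOptions(),
)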
ax/benchmark/benchmark.py: 2 changes (1 addition, 1 deletion)
@@ -104,7 +104,7 @@ def benchmark_replication(
scheduler.run_n_trials(max_trials=problem.num_trials)

optimization_trace = np.array(scheduler.get_trace())
num_baseline_trials = scheduler.generation_strategy._steps[0].num_trials
num_baseline_trials = scheduler.standard_generation_strategy._steps[0].num_trials
score_trace = compute_score_trace(
optimization_trace=optimization_trace,
num_baseline_trials=num_baseline_trials,
ax/service/scheduler.py: 53 changes (32 additions, 21 deletions)
@@ -29,6 +29,7 @@
import ax.service.utils.early_stopping as early_stopping_utils
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
@@ -161,7 +162,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
"""

experiment: Experiment
generation_strategy: GenerationStrategy
generation_strategy: GenerationStrategyInterface
options: SchedulerOptions
logger: LoggerAdapter
# Mapping of form {short string identifier -> message to show in reported
@@ -205,7 +206,7 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
def __init__(
self,
experiment: Experiment,
generation_strategy: GenerationStrategy,
generation_strategy: GenerationStrategyInterface,
options: SchedulerOptions,
db_settings: Optional[DBSettings] = None,
_skip_experiment_save: bool = False,
@@ -220,7 +221,7 @@ def __init__(

if not isinstance(experiment, Experiment):
            raise TypeError(f"{experiment} is not an Ax experiment.")
if not isinstance(generation_strategy, GenerationStrategy):
if not isinstance(generation_strategy, GenerationStrategyInterface):
            raise TypeError(f"{generation_strategy} is not a generation strategy.")
self._validate_options(options=options)

@@ -381,6 +382,21 @@ def runner(self) -> Runner:
"""
return not_none(self.experiment.runner)

@property
def standard_generation_strategy(self) -> GenerationStrategy:
"""Used for operations in the scheduler that can only be done with
        an instance of ``GenerationStrategy``.
"""
gs = self.generation_strategy
if not isinstance(gs, GenerationStrategy):
            raise NotImplementedError(
                "This functionality is only supported with instances of "
                "`GenerationStrategy` and not yet with other types of "
                "`GenerationStrategyInterface`."
            )
return gs

def __repr__(self) -> str:
"""Short user-friendly string representation."""
if not hasattr(self, "experiment"):
@@ -446,7 +462,7 @@ def get_best_trial(
) -> Optional[Tuple[int, TParameterization, Optional[TModelPredictArm]]]:
return self._get_best_trial(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
@@ -461,7 +477,7 @@ def get_pareto_optimal_parameters(
) -> Optional[Dict[int, Tuple[TParameterization, TModelPredictArm]]]:
return self._get_pareto_optimal_parameters(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
@@ -476,7 +492,7 @@ def get_hypervolume(
) -> float:
return BestPointMixin._get_hypervolume(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
@@ -1449,12 +1465,9 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
Returns:
List of trials, empty if generation is not possible.
"""
pending = get_pending_observation_features_based_on_trial_status(
experiment=self.experiment
)
try:
generator_runs = self._gen_new_trials_from_generation_strategy(
num_trials=num_trials, n=n, pending=pending
num_trials=num_trials, n=n
)
except OptimizationComplete as err:
completion_str = f"Optimization complete: {err}"
@@ -1483,42 +1496,40 @@ def _get_next_trials(self, num_trials: int = 1, n: int = 1) -> List[BaseTrial]:
self.logger.debug(f"Message from generation strategy: {err}")
return []

if (
self.options.trial_type == TrialType.TRIAL
and len(generator_runs[0].arms) > 1
if self.options.trial_type == TrialType.TRIAL and any(
len(generator_run_list[0].arms) > 1 or len(generator_run_list) > 1
for generator_run_list in generator_runs
):
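            # With `TrialType.TRIAL`, each trial must be backed by exactly one
            # single-arm `GeneratorRun`, so both multi-arm generator runs and
            # multiple generator runs per trial are rejected here.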
raise SchedulerInternalError(
"Generation strategy produced multiple arms when only one was expected."
)

return [
self.experiment.new_batch_trial(
generator_run=generator_run,
generator_runs=generator_run_list,
ttl_seconds=self.options.ttl_seconds_for_trials,
)
if self.options.trial_type == TrialType.BATCH_TRIAL
else self.experiment.new_trial(
generator_run=generator_run,
generator_run=generator_run_list[0],
ttl_seconds=self.options.ttl_seconds_for_trials,
)
for generator_run in generator_runs
for generator_run_list in generator_runs
]
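        # Note: `generator_runs` above is a `List[List[GeneratorRun]]`; each inner
        # list backs one trial, so batch trials receive the whole list via
        # `generator_runs=generator_run_list`, while single-arm trials take its
        # only element via `generator_run_list[0]`.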

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
n: int,
pending: Optional[Dict[str, List[ObservationFeatures]]],
) -> List[GeneratorRun]:
) -> List[List[GeneratorRun]]:
"""Generates a list ``GeneratorRun``s of length of ``num_trials`` using the
``_gen_multiple`` method of the scheduler's ``generation_strategy``, taking
into account any ``pending`` observations.
"""
return self.generation_strategy._gen_multiple(
return self.generation_strategy.gen_for_multiple_trials_with_multiple_models(
experiment=self.experiment,
num_generator_runs=num_trials,
n=n,
pending_observations=pending,
)

def _update_and_save_trials(
@@ -1853,7 +1864,7 @@ def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
Returns:
A ModelBridge object fitted to the observations of the scheduler's experiment.
"""
gs = scheduler.generation_strategy # GenerationStrategy
gs = scheduler.standard_generation_strategy
model_bridge = gs.model # Optional[ModelBridge]
if model_bridge is None: # Need to re-fit the model.
data = scheduler.experiment.fetch_data()
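The new `standard_generation_strategy` property gives call sites that rely on `GenerationStrategy`-specific attributes (for example `_steps`, `_fit_current_model`, and `model` in this diff) a single, loudly-failing narrowing point, while generation itself keeps going through `self.generation_strategy`, which may be any `GenerationStrategyInterface`. A standalone, simplified sketch of that narrowing pattern; the class and function names below are illustrative, not Ax APIs:

class StrategyInterface:  # stands in for GenerationStrategyInterface
    pass

class StandardStrategy(StrategyInterface):  # stands in for GenerationStrategy
    pass

def require_standard(strategy: StrategyInterface) -> StandardStrategy:
    # Narrow the interface type to the concrete class, or fail loudly.
    if not isinstance(strategy, StandardStrategy):
        raise NotImplementedError(
            "Only supported with instances of StandardStrategy."
        )
    return strategy

assert isinstance(require_standard(StandardStrategy()), StandardStrategy)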
ax/service/tests/test_scheduler.py: 39 changes (32 additions, 7 deletions)
@@ -59,6 +59,7 @@
get_branin_search_space,
get_generator_run,
get_sobol,
SpecialGenerationStrategy,
)

from pyre_extensions import none_throws
@@ -350,9 +351,10 @@ def test_validate_early_stopping_strategy(self) -> None:
),
)

@patch(
f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
return_value=[get_generator_run()],
@patch.object(
GenerationStrategy,
"gen_for_multiple_trials_with_multiple_models",
return_value=[[get_generator_run()]],
)
def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None:
scheduler = Scheduler(
@@ -366,7 +368,7 @@ def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None:
@patch(
# Record calls to function, but still execute it.
(
f"{Scheduler.__module__}."
f"{GenerationStrategy.__module__}."
"get_pending_observation_features_based_on_trial_status"
),
side_effect=get_pending_observation_features_based_on_trial_status,
@@ -987,8 +989,9 @@ def test_base_report_results(self) -> None:
)
self.assertEqual(scheduler.run_n_trials(max_trials=3), OptimizationResult())

@patch(
f"{GenerationStrategy.__module__}.GenerationStrategy._gen_multiple",
@patch.object(
GenerationStrategy,
"gen_for_multiple_trials_with_multiple_models",
side_effect=OptimizationComplete("test error"),
)
def test_optimization_complete(self, _) -> None:
@@ -1327,7 +1330,7 @@ def test_get_fitted_model_bridge(self) -> None:
with patch.object(
GenerationStrategy,
"_fit_current_model",
wraps=scheduler.generation_strategy._fit_current_model,
wraps=generation_strategy._fit_current_model,
) as fit_model:
get_fitted_model_bridge(scheduler)
fit_model.assert_called_once()
@@ -1359,3 +1362,25 @@ def test_get_fitted_model_bridge(self) -> None:
)
self.assertIsInstance(empty_metrics, dict)
self.assertTrue(len(empty_metrics) == 0)

def test_standard_generation_strategy(self) -> None:
        with self.subTest("with a `GenerationStrategy`"):
# Tests standard GS creation.
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=self.sobol_GPEI_GS,
options=SchedulerOptions(),
)
self.assertEqual(scheduler.standard_generation_strategy, self.sobol_GPEI_GS)

with self.subTest("with a `SpecialGenerationStrategy`"):
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=SpecialGenerationStrategy(),
options=SchedulerOptions(),
)
with self.assertRaisesRegex(
NotImplementedError,
"only supported with instances of `GenerationStrategy`",
):
scheduler.standard_generation_strategy
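
For reference, the `patch.object` pattern used in the updated tests can be exercised in isolation with plain `unittest.mock`; the `Strategy` class below is a stand-in for illustration, not an Ax class:

from unittest.mock import patch

class Strategy:
    def gen_for_multiple_trials_with_multiple_models(self):
        return [["real generator run"]]

with patch.object(
    Strategy,
    "gen_for_multiple_trials_with_multiple_models",
    return_value=[["mocked generator run"]],
) as mock_gen:
    assert Strategy().gen_for_multiple_trials_with_multiple_models() == [
        ["mocked generator run"]
    ]
mock_gen.assert_called_once()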
ax/telemetry/generation_strategy.py: 1 change (0 additions, 1 deletion)
@@ -9,7 +9,6 @@
from math import inf

from ax.modelbridge.generation_strategy import GenerationStrategy

from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS


ax/telemetry/scheduler.py: 4 changes (2 additions, 2 deletions)
@@ -49,7 +49,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
),
generation_strategy_created_record=(
GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.standard_generation_strategy
)
),
scheduler_total_trials=scheduler.options.total_trials,
@@ -68,7 +68,7 @@ def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
),
transformed_dimensionality=_get_max_transformed_dimensionality(
search_space=scheduler.experiment.search_space,
generation_strategy=scheduler.generation_strategy,
generation_strategy=scheduler.standard_generation_strategy,
),
)

ax/telemetry/tests/test_scheduler.py: 4 changes (2 additions, 2 deletions)
@@ -45,7 +45,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
),
generation_strategy_created_record=(
GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.standard_generation_strategy
)
),
scheduler_total_trials=0,
@@ -63,7 +63,7 @@ def test_scheduler_created_record_from_scheduler(self) -> None:
experiment=scheduler.experiment
).__dict__,
**GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=scheduler.generation_strategy
generation_strategy=scheduler.standard_generation_strategy
).__dict__,
"scheduler_total_trials": 0,
"scheduler_max_pending_trials": 10,
