diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index f30f29a54b8..179d611dadb 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -29,7 +29,7 @@ BaseEarlyStoppingStrategy, PercentileEarlyStoppingStrategy, ) -from ax.exceptions.core import UnsupportedError +from ax.exceptions.core import ObjectNotFoundError, UnsupportedError from ax.modelbridge.generation_strategy import GenerationStrategy from ax.preview.api.configs import ( ExperimentConfig, @@ -44,9 +44,11 @@ from ax.preview.api.utils.instantiation.from_string import ( optimization_config_from_string, ) +from ax.preview.api.utils.storage import db_settings_from_storage_config from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy from ax.service.scheduler import Scheduler, SchedulerOptions from ax.service.utils.best_point_mixin import BestPointMixin +from ax.service.utils.with_db_settings_base import WithDBSettingsBase from ax.storage.json_store.decoder import ( generation_strategy_from_json, object_from_json, @@ -66,7 +68,7 @@ logger: Logger = get_logger(__name__) -class Client: +class Client(WithDBSettingsBase): _maybe_experiment: Experiment | None = None _maybe_generation_strategy: GenerationStrategy | None = None _maybe_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None @@ -86,6 +88,13 @@ def __init__( of the experiment's results. If not provided, the random seed will not be set, leading to potentially different results on different runs. """ + + super().__init__( # Initialize WithDBSettingsBase + db_settings=db_settings_from_storage_config(storage_config=storage_config) + if storage_config is not None + else None, + ) + self._storage_config = storage_config self._random_seed = random_seed @@ -109,9 +118,7 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None: self._maybe_experiment = experiment_from_config(config=experiment_config) - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) def configure_optimization( self, @@ -154,9 +161,7 @@ def configure_optimization( outcome_constraint_strs=outcome_constraints, ) - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) def configure_generation_strategy( self, generation_strategy_config: GenerationStrategyConfig @@ -177,9 +182,9 @@ def configure_generation_strategy( self._maybe_generation_strategy = generation_strategy - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_generation_strategy_to_db_if_possible( + generation_strategy=self._generation_strategy + ) # -------------------- Section 1.1: Configure Automation ------------------------ def configure_runner(self, runner: IRunner) -> None: @@ -212,9 +217,7 @@ def set_experiment(self, experiment: Experiment) -> None: """ self._maybe_experiment = experiment - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) def set_optimization_config(self, optimization_config: OptimizationConfig) -> None: """ @@ -228,9 +231,7 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> No """ self._experiment.optimization_config = optimization_config - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> None: """ @@ -244,11 +245,11 @@ def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> No """ self._maybe_generation_strategy = generation_strategy - none_throws(self._maybe_generation_strategy)._experiment = self._experiment + self._generation_strategy._experiment = self._experiment - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_generation_strategy_to_db_if_possible( + generation_strategy=self._generation_strategy + ) def set_early_stopping_strategy( self, early_stopping_strategy: BaseEarlyStoppingStrategy @@ -265,10 +266,6 @@ def set_early_stopping_strategy( """ self._maybe_early_stopping_strategy = early_stopping_strategy - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... - def _set_runner(self, runner: Runner) -> None: """ This method is not part of the API and is provided (without guarantees of @@ -281,9 +278,9 @@ def _set_runner(self, runner: Runner) -> None: """ self._experiment.runner = runner - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._update_runner_on_experiment_in_db_if_possible( + experiment=self._experiment, runner=runner + ) def _set_metrics(self, metrics: Sequence[Metric]) -> None: """ @@ -303,9 +300,7 @@ def _set_metrics(self, metrics: Sequence[Metric]) -> None: # Check the optimization config first self._overwrite_metric(metric=metric) - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) # -------------------- Section 2. Conduct Experiment ---------------------------- def get_next_trials( @@ -366,9 +361,10 @@ def get_next_trials( trials.append(trial) - if self._storage_config is not None: - # TODO[mpolson64] Save trial and update generation strategy - ... + # Bulk save all trials to the database if possible + self._save_or_update_trials_in_db_if_possible( + experiment=self._experiment, trials=trials + ) # pyre-fixme[7]: Core Ax allows users to specify TParameterization values as # None, but we do not allow this in the API. @@ -417,9 +413,9 @@ def complete_trial( ) self.mark_trial_failed(trial_index=trial_index) - if self._storage_config is not None: - # TODO[mpolson64] Save trial - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=self._experiment.trials[trial_index] + ) return self._experiment.trials[trial_index].status @@ -452,9 +448,9 @@ def attach_data( combine_with_last_data=True, ) - if self._storage_config is not None: - # TODO[mpolson64] Save trial - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=trial + ) # -------------------- Section 2.1 Custom trials -------------------------------- def attach_trial( @@ -477,9 +473,9 @@ def attach_trial( arm_names=[arm_name] if arm_name else None, ) - if self._storage_config is not None: - # TODO[mpolson64] Save trial - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=self._experiment.trials[trial_index] + ) return trial_index @@ -506,8 +502,7 @@ def attach_baseline( self._experiment.trials[trial_index], Trial ).arm - if self._storage_config is not None: - ... + self._save_experiment_to_db_if_possible(experiment=self._experiment) return trial_index @@ -542,9 +537,9 @@ def mark_trial_failed(self, trial_index: int) -> None: """ self._experiment.trials[trial_index].mark_failed() - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=self._experiment.trials[trial_index] + ) def mark_trial_abandoned(self, trial_index: int) -> None: """ @@ -556,9 +551,9 @@ def mark_trial_abandoned(self, trial_index: int) -> None: """ self._experiment.trials[trial_index].mark_abandoned() - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=self._experiment.trials[trial_index] + ) def mark_trial_early_stopped( self, trial_index: int, raw_data: TOutcome, progression: int | None = None @@ -579,9 +574,9 @@ def mark_trial_early_stopped( self._experiment.trials[trial_index].mark_early_stopped() - if self._storage_config is not None: - # TODO[mpolson64] Save to database - ... + self._save_or_update_trial_in_db_if_possible( + experiment=self._experiment, trial=self._experiment.trials[trial_index] + ) def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: """ @@ -600,7 +595,9 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, init_seconds_between_polls=options.initial_seconds_between_polls, ), - # TODO[mpolson64] Add db_settings=self._storage_config when adding storage + db_settings=db_settings_from_storage_config(self._storage_config) + if self._storage_config is not None + else None, ) # Note: This scheduler call will handle storage internally @@ -649,9 +646,9 @@ def compute_analyses( for result in results ] - if self._storage_config is not None: - # TODO[mpolson64] Save cards to database - ... + self._save_analysis_cards_to_db_if_possible( + experiment=self._experiment, analysis_cards=cards + ) return cards @@ -816,20 +813,22 @@ def load_from_json_file( storage_config: StorageConfig | None = None, ) -> Self: """ - Restore an `AxClient` and its state from a JSON-serialized snapshot, + Restore a `Client` and its state from a JSON-serialized snapshot, residing in a .json file by the given path. Returns: - The restored `AxClient`. + The restored `Client`. """ with open(filepath) as file: return cls._from_json_snapshot( snapshot=json.loads(file.read()), storage_config=storage_config ) + @classmethod def load_from_database( - self, + cls, experiment_name: str, + storage_config: StorageConfig | None = None, ) -> Self: """ Restore an `AxClient` and its state from database by the given name. @@ -837,7 +836,32 @@ def load_from_database( Returns: The restored `AxClient`. """ - ... + db_settings_base = WithDBSettingsBase( + db_settings=db_settings_from_storage_config(storage_config=storage_config) + if storage_config is not None + else None + ) + + maybe_experiment, maybe_generation_strategy = ( + db_settings_base._load_experiment_and_generation_strategy( + experiment_name=experiment_name + ) + ) + if (experiment := maybe_experiment) is None: + raise ObjectNotFoundError( + f"Experiment {experiment_name} not found in database. Please check " + "its name is correct, check your StorageConfig is correct, or create " + "a new experiment." + ) + + client = cls(storage_config=storage_config) + client.set_experiment(experiment=experiment) + if maybe_generation_strategy is not None: + client.set_generation_strategy( + generation_strategy=maybe_generation_strategy + ) + + return client # -------------------- Section 5: Private Methods ------------------------------- # -------------------- Section 5.1: Getters and defaults ------------------------ diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py index 4f3827769b4..43c8ec2f850 100644 --- a/ax/preview/api/configs.py +++ b/ax/preview/api/configs.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Mapping, Sequence +from typing import Any, Callable, List, Mapping, Sequence from ax.preview.api.types import TParameterValue from ax.storage.registry_bundle import RegistryBundle @@ -161,5 +161,6 @@ class OrchestrationConfig: @dataclass class StorageConfig: + creator: Callable[..., Any] | None = None # pyre-fixme[4] url: str | None = None registry_bundle: RegistryBundle | None = None diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 8bad5df1c14..1f83295ac1d 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -38,10 +38,12 @@ OrchestrationConfig, ParameterType, RangeParameterConfig, + StorageConfig, ) from ax.preview.api.protocols.metric import IMetric from ax.preview.api.protocols.runner import IRunner from ax.preview.api.types import TParameterization +from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -1128,6 +1130,62 @@ def test_json_storage(self) -> None: str(client._generation_strategy), str(other_client._generation_strategy) ) + def test_sql_storage(self) -> None: + init_test_engine_and_session_factory(force_init=True) + client = Client(storage_config=StorageConfig()) + + # Experiment with relatively complicated search space + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + RangeParameterConfig( + name="x2", parameter_type=ParameterType.INT, bounds=(-1, 1) + ), + ChoiceParameterConfig( + name="x3", + parameter_type=ParameterType.STRING, + values=["a", "b"], + ), + ChoiceParameterConfig( + name="x4", + parameter_type=ParameterType.INT, + values=[1, 2, 3], + is_ordered=True, + ), + ChoiceParameterConfig( + name="x5", parameter_type=ParameterType.INT, values=[1] + ), + ], + name="unique_test_experiment", + ) + ) + + # Relatively complicated optimization config + client.configure_optimization( + objective="foo + 2 * bar", outcome_constraints=["baz >= 0"] + ) + + # Specified generation strategy + client.configure_generation_strategy( + generation_strategy_config=GenerationStrategyConfig( + initialization_budget=3, + ) + ) + + other_client = Client.load_from_database( + experiment_name="unique_test_experiment", storage_config=StorageConfig() + ) + + self.assertEqual(client._experiment, other_client._experiment) + # Don't check for deep equality of GenerationStrategy since the other gs will + # not have all its attributes initialized, but ensure they have the same repr + self.assertEqual( + str(client._generation_strategy), str(other_client._generation_strategy) + ) + class DummyRunner(IRunner): @override diff --git a/ax/preview/api/utils/storage.py b/ax/preview/api/utils/storage.py new file mode 100644 index 00000000000..3e6f4254aad --- /dev/null +++ b/ax/preview/api/utils/storage.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.preview.api.configs import StorageConfig +from ax.storage.sqa_store.decoder import Decoder +from ax.storage.sqa_store.encoder import Encoder +from ax.storage.sqa_store.sqa_config import SQAConfig +from ax.storage.sqa_store.structs import DBSettings + + +def db_settings_from_storage_config( + storage_config: StorageConfig, +) -> DBSettings: + """Construct DBSettings (expected by WithDBSettingsBase) from StorageConfig.""" + if (bundle := storage_config.registry_bundle) is not None: + encoder = bundle.encoder + decoder = bundle.decoder + else: + encoder = Encoder(config=SQAConfig()) + decoder = Decoder(config=SQAConfig()) + + return DBSettings( + creator=storage_config.creator, + url=storage_config.url, + encoder=encoder, + decoder=decoder, + ) diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index e3b09d9e034..fa8df73c864 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -11,7 +11,7 @@ from collections.abc import Iterable from logging import INFO, Logger -from typing import Optional +from typing import Optional, Sequence from ax.analysis.analysis import AnalysisCard @@ -19,6 +19,7 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun +from ax.core.runner import Runner from ax.exceptions.core import ( IncompatibleDependencyVersion, ObjectNotFoundError, @@ -69,6 +70,7 @@ _update_generation_strategy, save_analysis_cards, update_properties_on_experiment, + update_runner_on_experiment, ) from ax.storage.sqa_store.sqa_config import SQAConfig from ax.storage.sqa_store.structs import DBSettings @@ -347,7 +349,7 @@ def _save_or_update_trial_in_db_if_possible( def _save_or_update_trials_in_db_if_possible( self, experiment: Experiment, - trials: list[BaseTrial], + trials: Sequence[BaseTrial], reduce_state_generator_runs: bool = False, ) -> bool: """Saves new trials or update existing trials on given experiment if DB @@ -437,6 +439,21 @@ def _update_generation_strategy_in_db_if_possible( return True return False + def _update_runner_on_experiment_in_db_if_possible( + self, experiment: Experiment, runner: Runner + ) -> bool: + if self.db_settings_set: + _update_runner_on_experiment_in_db_if_possible( + experiment=experiment, + runner=runner, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + ) + return True + + return False + def _update_experiment_properties_in_db( self, experiment_with_updated_properties: Experiment, @@ -575,6 +592,23 @@ def _update_generation_strategy_in_db_if_possible( ) +@retry_on_exception( + retries=3, + default_return_on_suppression=False, + exception_types=RETRY_EXCEPTION_TYPES, +) +def _update_runner_on_experiment_in_db_if_possible( + experiment: Experiment, + runner: Runner, + encoder: Encoder, + decoder: Decoder, + suppress_all_errors: bool, # Used by the decorator. +) -> None: + update_runner_on_experiment( + experiment=experiment, runner=runner, encoder=encoder, decoder=decoder + ) + + @retry_on_exception( retries=3, default_return_on_suppression=False, diff --git a/sphinx/source/preview.rst b/sphinx/source/preview.rst index 99edc3f3fea..333876c2dc8 100644 --- a/sphinx/source/preview.rst +++ b/sphinx/source/preview.rst @@ -94,3 +94,11 @@ Dispatch Utils :members: :undoc-members: :show-inheritance: + +Storage Utils +~~~~~~~~~~~~~ + +.. automodule:: ax.preview.api.utils.storage + :members: + :undoc-members: + :show-inheritance: