From 9a936ae311cce5922eb5680c6c17102356aab48c Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 18 Dec 2024 07:53:59 -0800 Subject: [PATCH] Implement json storage Summary: Implement two new methods: save_to_json_file and load_from_json_file Also renamed DBConfig to StorageConfig, made URL optional, and added optional RegistryBundle field. Reviewed By: lena-kashtelyan Differential Revision: D67159759 --- ax/preview/api/client.py | 192 +++++++++++++++++++++------- ax/preview/api/configs.py | 6 +- ax/preview/api/tests/test_client.py | 57 +++++++++ 3 files changed, 208 insertions(+), 47 deletions(-) diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index a92828d57cb..f30f29a54b8 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -5,8 +5,9 @@ # pyre-strict +import json from logging import Logger -from typing import Sequence +from typing import Any, Sequence import numpy as np @@ -31,10 +32,10 @@ from ax.exceptions.core import UnsupportedError from ax.modelbridge.generation_strategy import GenerationStrategy from ax.preview.api.configs import ( - DatabaseConfig, ExperimentConfig, GenerationStrategyConfig, OrchestrationConfig, + StorageConfig, ) from ax.preview.api.protocols.metric import IMetric from ax.preview.api.protocols.runner import IRunner @@ -46,6 +47,17 @@ from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy from ax.service.scheduler import Scheduler, SchedulerOptions from ax.service.utils.best_point_mixin import BestPointMixin +from ax.storage.json_store.decoder import ( + generation_strategy_from_json, + object_from_json, +) +from ax.storage.json_store.encoder import object_to_json +from ax.storage.json_store.registry import ( + CORE_CLASS_DECODER_REGISTRY, + CORE_CLASS_ENCODER_REGISTRY, + CORE_DECODER_REGISTRY, + CORE_ENCODER_REGISTRY, +) from ax.utils.common.logger import get_logger from ax.utils.common.random import with_rng_seed from pyre_extensions import assert_is_instance, none_throws @@ -61,20 +73,20 @@ class Client: def __init__( self, - db_config: DatabaseConfig | None = None, + storage_config: StorageConfig | None = None, random_seed: int | None = None, ) -> None: """ Initialize a Client, which manages state across the lifecycle of an experiment. Args: - db_config: Configuration for saving to and loading from a database. If + storage_config: Configuration for saving to and loading from a database. If elided the experiment will not automatically be saved to a database. random_seed: An optional integer to set the random seed for reproducibility of the experiment's results. If not provided, the random seed will not be set, leading to potentially different results on different runs. """ - self._db_config = db_config + self._storage_config = storage_config self._random_seed = random_seed # -------------------- Section 1: Configure -------------------------------------- @@ -87,7 +99,7 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None: This method only constitutes defining the search space and misc. metadata like name, description, and owners. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ if self._maybe_experiment is not None: raise UnsupportedError( @@ -97,7 +109,7 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None: self._maybe_experiment = experiment_from_config(config=experiment_config) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -134,7 +146,7 @@ def configure_optimization( Note that scalarized outcome constraints cannot be relative. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._experiment.optimization_config = optimization_config_from_string( @@ -142,7 +154,7 @@ def configure_optimization( outcome_constraint_strs=outcome_constraints, ) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -153,7 +165,7 @@ def configure_generation_strategy( Overwrite the existing GenerationStrategy by calling choose_gs using the arguments of the GenerationStrategyConfig as parameters. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ generation_strategy = choose_generation_strategy( @@ -165,7 +177,7 @@ def configure_generation_strategy( self._maybe_generation_strategy = generation_strategy - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -174,7 +186,7 @@ def configure_runner(self, runner: IRunner) -> None: """ Attaches a Runner to the Experiment. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._set_runner(runner=runner) @@ -196,11 +208,11 @@ def set_experiment(self, experiment: Experiment) -> None: Overwrite the existing Experiment with the provided Experiment. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._maybe_experiment = experiment - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -212,11 +224,11 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> No Overwrite the existing OptimizationConfig with the provided OptimizationConfig. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._experiment.optimization_config = optimization_config - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -228,13 +240,13 @@ def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> No Overwrite the existing GenerationStrategy with the provided GenerationStrategy. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._maybe_generation_strategy = generation_strategy none_throws(self._maybe_generation_strategy)._experiment = self._experiment - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -249,11 +261,11 @@ def set_early_stopping_strategy( Overwrite the existing EarlyStoppingStrategy with the provided EarlyStoppingStrategy. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._maybe_early_stopping_strategy = early_stopping_strategy - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -265,11 +277,11 @@ def _set_runner(self, runner: Runner) -> None: Attaches a Runner to the Experiment. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._experiment.runner = runner - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -291,7 +303,7 @@ def _set_metrics(self, metrics: Sequence[Metric]) -> None: # Check the optimization config first self._overwrite_metric(metric=metric) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -308,7 +320,7 @@ def get_next_trials( This will need to be rethought somewhat when we add support for BatchTrials, but will be okay for current supported functionality. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. Returns: A mapping of trial index to parameterization. @@ -354,7 +366,7 @@ def get_next_trials( trials.append(trial) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save trial and update generation strategy ... @@ -377,7 +389,7 @@ def complete_trial( - If any metrics on the OptimizationConfig are missing the trial will be marked as FAILED - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ if raw_data is not None: self.attach_data( @@ -405,7 +417,7 @@ def complete_trial( ) self.mark_trial_failed(trial_index=trial_index) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save trial ... @@ -423,7 +435,7 @@ def attach_data( tracking metrics. If progression is provided the Experiment will be updated to use MapData and the data will be attached to the appropriate step. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ # If no progression is provided assume the data is not timeseries-like and @@ -440,7 +452,7 @@ def attach_data( combine_with_last_data=True, ) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save trial ... @@ -453,7 +465,7 @@ def attach_trial( The trial will be marked as RUNNING and must be completed manually by the user. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. Returns: The index of the attached trial. @@ -465,7 +477,7 @@ def attach_trial( arm_names=[arm_name] if arm_name else None, ) - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save trial ... @@ -483,7 +495,7 @@ def attach_baseline( Returns: The index of the attached trial. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ trial_index = self.attach_trial( parameters=parameters, @@ -494,7 +506,7 @@ def attach_baseline( self._experiment.trials[trial_index], Trial ).arm - if self._db_config is not None: + if self._storage_config is not None: ... return trial_index @@ -526,11 +538,11 @@ def mark_trial_failed(self, trial_index: int) -> None: Manually mark a trial as FAILED. FAILED trials typically may be re-suggested by get_next_trials, though this is controlled by the GenerationStrategy. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._experiment.trials[trial_index].mark_failed() - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -540,11 +552,11 @@ def mark_trial_abandoned(self, trial_index: int) -> None: be re-suggested by get_next_trials, though this is controlled by the GenerationStrategy. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ self._experiment.trials[trial_index].mark_abandoned() - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -557,7 +569,7 @@ def mark_trial_early_stopped( stop the trial early. EARLY_STOPPED trials will not be re-suggested by get_next_trials. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ # First attach the new data @@ -567,7 +579,7 @@ def mark_trial_early_stopped( self._experiment.trials[trial_index].mark_early_stopped() - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save to database ... @@ -577,7 +589,7 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: the hood using the Experiment, GenerationStrategy, Metrics, and Runner attached to this AxClient along with the provided OrchestrationConfig. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. """ scheduler = Scheduler( @@ -588,7 +600,7 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None: tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, init_seconds_between_polls=options.initial_seconds_between_polls, ), - # TODO[mpolson64] Add db_settings=self._db_config when adding storage + # TODO[mpolson64] Add db_settings=self._storage_config when adding storage ) # Note: This scheduler call will handle storage internally @@ -609,7 +621,7 @@ def compute_analyses( to change incompatibly between minor versions. Users are encouraged to use the provided analyses or leave this argument as None to use the default analyses. - Saves to database on completion if db_config is present. + Saves to database on completion if storage_config is present. Returns: A list of AnalysisCards. @@ -637,7 +649,7 @@ def compute_analyses( for result in results ] - if self._db_config is not None: + if self._storage_config is not None: # TODO[mpolson64] Save cards to database ... @@ -793,12 +805,15 @@ def save_to_json_file(self, filepath: str = "ax_client_snapshot.json") -> None: Save a JSON-serialized snapshot of this `AxClient`'s settings and state to a .json file by the given path. """ - ... + with open(filepath, "w+") as file: + file.write(json.dumps(self._to_json_snapshot())) + logger.info(f"Saved JSON-serialized state of optimization to `{filepath}`.") @classmethod def load_from_json_file( cls, filepath: str = "ax_client_snapshot.json", + storage_config: StorageConfig | None = None, ) -> Self: """ Restore an `AxClient` and its state from a JSON-serialized snapshot, @@ -807,7 +822,10 @@ def load_from_json_file( Returns: The restored `AxClient`. """ - ... + with open(filepath) as file: + return cls._from_json_snapshot( + snapshot=json.loads(file.read()), storage_config=storage_config + ) def load_from_database( self, @@ -890,6 +908,7 @@ def _early_stopping_strategy_or_choose( return self._early_stopping_strategy + # -------------------- Section 5.2: Metric configuration -------------------------- def _overwrite_metric(self, metric: Metric) -> None: """ Overwrite an existing Metric on the Experiment with the provided Metric if they @@ -948,3 +967,86 @@ def _overwrite_metric(self, metric: Metric) -> None: f"Metric {metric} not found in optimization config, added as tracking " "metric." ) + + # -------------------- Section 5.3: Storage utilies ------------------------------- + def _to_json_snapshot(self) -> dict[str, Any]: + """Serialize this `AxClient` to JSON to be able to interrupt and restart + optimization and save it to file by the provided path. + + Returns: + A JSON-safe dict representation of this `AxClient`. + """ + + # If the user has supplied custom encoder registries, use them. Otherwise use + # the core encoder registries. + if ( + self._storage_config is not None + and self._storage_config.registry_bundle is not None + ): + encoder_registry = ( + self._storage_config.registry_bundle.sqa_config.json_encoder_registry + ) + class_encoder_registry = self._storage_config.registry_bundle.sqa_config.json_class_encoder_registry # noqa: E501 + else: + encoder_registry = CORE_ENCODER_REGISTRY + class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY + + return { + "_type": self.__class__.__name__, + "experiment": object_to_json( + self._experiment, + encoder_registry=encoder_registry, + class_encoder_registry=class_encoder_registry, + ), + "generation_strategy": object_to_json( + self._generation_strategy, + encoder_registry=encoder_registry, + class_encoder_registry=class_encoder_registry, + ) + if self._maybe_generation_strategy is not None + else None, + } + + @classmethod + def _from_json_snapshot( + cls, + snapshot: dict[str, Any], + storage_config: StorageConfig | None = None, + ) -> Self: + # If the user has supplied custom encoder registries, use them. Otherwise use + # the core encoder registries. + if storage_config is not None and storage_config.registry_bundle is not None: + decoder_registry = ( + storage_config.registry_bundle.sqa_config.json_decoder_registry + ) + class_decoder_registry = ( + storage_config.registry_bundle.sqa_config.json_class_decoder_registry + ) + else: + decoder_registry = CORE_DECODER_REGISTRY + class_decoder_registry = CORE_CLASS_DECODER_REGISTRY + + # Decode the experiment, and generation strategy if present + experiment = object_from_json( + object_json=snapshot["experiment"], + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + + generation_strategy = ( + generation_strategy_from_json( + generation_strategy_json=snapshot["generation_strategy"], + experiment=experiment, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + if "generation_strategy" in snapshot + else None + ) + + client = cls(storage_config=storage_config) + client.set_experiment(experiment=experiment) + if generation_strategy is not None: + client.set_generation_strategy(generation_strategy=generation_strategy) + + return client diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py index c0ca319a155..4f3827769b4 100644 --- a/ax/preview/api/configs.py +++ b/ax/preview/api/configs.py @@ -10,6 +10,7 @@ from typing import List, Mapping, Sequence from ax.preview.api.types import TParameterValue +from ax.storage.registry_bundle import RegistryBundle class ParameterType(Enum): @@ -159,5 +160,6 @@ class OrchestrationConfig: @dataclass -class DatabaseConfig: - url: str +class StorageConfig: + url: str | None = None + registry_bundle: RegistryBundle | None = None diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index eba93d17d67..8bad5df1c14 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -1071,6 +1071,63 @@ def test_predict(self) -> None: point = client.predict(points=[{"x1": 0.5}]) self.assertEqual({*point[0].keys()}, {"foo", "bar"}) + def test_json_storage(self) -> None: + client = Client() + + # Experiment with relatively complicated search space + client.configure_experiment( + experiment_config=ExperimentConfig( + parameters=[ + RangeParameterConfig( + name="x1", parameter_type=ParameterType.FLOAT, bounds=(-1, 1) + ), + RangeParameterConfig( + name="x2", parameter_type=ParameterType.INT, bounds=(-1, 1) + ), + ChoiceParameterConfig( + name="x3", + parameter_type=ParameterType.STRING, + values=["a", "b"], + ), + ChoiceParameterConfig( + name="x4", + parameter_type=ParameterType.INT, + values=[1, 2, 3], + is_ordered=True, + ), + ChoiceParameterConfig( + name="x5", parameter_type=ParameterType.INT, values=[1] + ), + ], + name="foo", + ) + ) + + # Relatively complicated optimization config + client.configure_optimization( + objective="foo + 2 * bar", outcome_constraints=["baz >= 0"] + ) + + # Specified generation strategy + client.configure_generation_strategy( + generation_strategy_config=GenerationStrategyConfig( + initialization_budget=2, + ) + ) + + # Use the Client a bit + _ = client.get_next_trials(maximum_trials=2) + + snapshot = client._to_json_snapshot() + other_client = Client._from_json_snapshot(snapshot=snapshot) + + self.assertEqual(client._experiment, other_client._experiment) + # Don't check for deep equality of GenerationStrategy since the other gs will + # not have all its attributes initialized, but ensure they have the same repr + self.assertEqual( + str(client._generation_strategy), str(other_client._generation_strategy) + ) + class DummyRunner(IRunner): @override