Skip to content

Commit

Permalink
Implement sql storage
Browse files Browse the repository at this point in the history
Summary: As titled. All methods on the client that modify experiment state now also save the experiment

Differential Revision: D67162623
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Dec 18, 2024
1 parent 020f8ce commit 4c471e8
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 65 deletions.
149 changes: 87 additions & 62 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
BaseEarlyStoppingStrategy,
PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import UnsupportedError
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.preview.api.configs import (
ExperimentConfig,
Expand All @@ -44,9 +44,11 @@
from ax.preview.api.utils.instantiation.from_string import (
optimization_config_from_string,
)
from ax.preview.api.utils.storage import db_settings_from_storage_config
from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy
from ax.service.scheduler import Scheduler, SchedulerOptions
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.with_db_settings_base import WithDBSettingsBase
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
object_from_json,
Expand All @@ -58,6 +60,7 @@
CORE_DECODER_REGISTRY,
CORE_ENCODER_REGISTRY,
)
from ax.storage.sqa_store.structs import DBSettings
from ax.utils.common.logger import get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance, none_throws
Expand All @@ -66,7 +69,7 @@
logger: Logger = get_logger(__name__)


class Client:
class Client(WithDBSettingsBase):
_maybe_experiment: Experiment | None = None
_maybe_generation_strategy: GenerationStrategy | None = None
_maybe_early_stopping_strategy: BaseEarlyStoppingStrategy | None = None
Expand All @@ -86,6 +89,13 @@ def __init__(
of the experiment's results. If not provided, the random seed will not
be set, leading to potentially different results on different runs.
"""

super().__init__( # Initialize WithDBSettingsBase
db_settings=db_settings_from_storage_config(storage_config=storage_config)
if storage_config is not None
else None,
)

self._storage_config = storage_config
self._random_seed = random_seed

Expand All @@ -109,9 +119,7 @@ def configure_experiment(self, experiment_config: ExperimentConfig) -> None:

self._maybe_experiment = experiment_from_config(config=experiment_config)

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

def configure_optimization(
self,
Expand Down Expand Up @@ -154,9 +162,7 @@ def configure_optimization(
outcome_constraint_strs=outcome_constraints,
)

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

def configure_generation_strategy(
self, generation_strategy_config: GenerationStrategyConfig
Expand All @@ -177,9 +183,9 @@ def configure_generation_strategy(

self._maybe_generation_strategy = generation_strategy

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_generation_strategy_to_db_if_possible(
generation_strategy=self._generation_strategy
)

# -------------------- Section 1.1: Configure Automation ------------------------
def configure_runner(self, runner: IRunner) -> None:
Expand Down Expand Up @@ -212,9 +218,7 @@ def set_experiment(self, experiment: Experiment) -> None:
"""
self._maybe_experiment = experiment

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

def set_optimization_config(self, optimization_config: OptimizationConfig) -> None:
"""
Expand All @@ -228,9 +232,7 @@ def set_optimization_config(self, optimization_config: OptimizationConfig) -> No
"""
self._experiment.optimization_config = optimization_config

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> None:
"""
Expand All @@ -244,11 +246,11 @@ def set_generation_strategy(self, generation_strategy: GenerationStrategy) -> No
"""
self._maybe_generation_strategy = generation_strategy

none_throws(self._maybe_generation_strategy)._experiment = self._experiment
self._generation_strategy._experiment = self._experiment

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_generation_strategy_to_db_if_possible(
generation_strategy=self._generation_strategy
)

def set_early_stopping_strategy(
self, early_stopping_strategy: BaseEarlyStoppingStrategy
Expand All @@ -265,10 +267,6 @@ def set_early_stopping_strategy(
"""
self._maybe_early_stopping_strategy = early_stopping_strategy

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...

def _set_runner(self, runner: Runner) -> None:
"""
This method is not part of the API and is provided (without guarantees of
Expand All @@ -281,9 +279,9 @@ def _set_runner(self, runner: Runner) -> None:
"""
self._experiment.runner = runner

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._update_runner_on_experiment_in_db_if_possible(
experiment=self._experiment, runner=runner
)

def _set_metrics(self, metrics: Sequence[Metric]) -> None:
"""
Expand All @@ -303,9 +301,7 @@ def _set_metrics(self, metrics: Sequence[Metric]) -> None:
# Check the optimization config first
self._overwrite_metric(metric=metric)

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

# -------------------- Section 2. Conduct Experiment ----------------------------
def get_next_trials(
Expand Down Expand Up @@ -366,9 +362,10 @@ def get_next_trials(

trials.append(trial)

if self._storage_config is not None:
# TODO[mpolson64] Save trial and update generation strategy
...
# Bulk save all trials to the database if possible
self._save_or_update_trials_in_db_if_possible(
experiment=self._experiment, trials=trials
)

# pyre-fixme[7]: Core Ax allows users to specify TParameterization values as
# None, but we do not allow this in the API.
Expand Down Expand Up @@ -417,9 +414,9 @@ def complete_trial(
)
self.mark_trial_failed(trial_index=trial_index)

if self._storage_config is not None:
# TODO[mpolson64] Save trial
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

return self._experiment.trials[trial_index].status

Expand Down Expand Up @@ -452,9 +449,9 @@ def attach_data(
combine_with_last_data=True,
)

if self._storage_config is not None:
# TODO[mpolson64] Save trial
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=trial
)

# -------------------- Section 2.1 Custom trials --------------------------------
def attach_trial(
Expand All @@ -477,9 +474,9 @@ def attach_trial(
arm_names=[arm_name] if arm_name else None,
)

if self._storage_config is not None:
# TODO[mpolson64] Save trial
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

return trial_index

Expand All @@ -506,8 +503,7 @@ def attach_baseline(
self._experiment.trials[trial_index], Trial
).arm

if self._storage_config is not None:
...
self._save_experiment_to_db_if_possible(experiment=self._experiment)

return trial_index

Expand Down Expand Up @@ -542,9 +538,9 @@ def mark_trial_failed(self, trial_index: int) -> None:
"""
self._experiment.trials[trial_index].mark_failed()

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

def mark_trial_abandoned(self, trial_index: int) -> None:
"""
Expand All @@ -556,9 +552,9 @@ def mark_trial_abandoned(self, trial_index: int) -> None:
"""
self._experiment.trials[trial_index].mark_abandoned()

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

def mark_trial_early_stopped(
self, trial_index: int, raw_data: TOutcome, progression: int | None = None
Expand All @@ -579,9 +575,9 @@ def mark_trial_early_stopped(

self._experiment.trials[trial_index].mark_early_stopped()

if self._storage_config is not None:
# TODO[mpolson64] Save to database
...
self._save_or_update_trial_in_db_if_possible(
experiment=self._experiment, trial=self._experiment.trials[trial_index]
)

def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None:
"""
Expand All @@ -600,7 +596,9 @@ def run_trials(self, maximum_trials: int, options: OrchestrationConfig) -> None:
tolerated_trial_failure_rate=options.tolerated_trial_failure_rate,
init_seconds_between_polls=options.initial_seconds_between_polls,
),
# TODO[mpolson64] Add db_settings=self._storage_config when adding storage
db_settings=db_settings_from_storage_config(self._storage_config)
if self._storage_config is not None
else None,
)

# Note: This scheduler call will handle storage internally
Expand Down Expand Up @@ -649,9 +647,9 @@ def compute_analyses(
for result in results
]

if self._storage_config is not None:
# TODO[mpolson64] Save cards to database
...
self._save_analysis_cards_to_db_if_possible(
experiment=self._experiment, analysis_cards=cards
)

return cards

Expand Down Expand Up @@ -816,28 +814,55 @@ def load_from_json_file(
storage_config: StorageConfig | None = None,
) -> Self:
"""
Restore an `AxClient` and its state from a JSON-serialized snapshot,
Restore a `Client` and its state from a JSON-serialized snapshot,
residing in a .json file by the given path.
Returns:
The restored `AxClient`.
The restored `Client`.
"""
with open(filepath) as file:
return cls._from_json_snapshot(
snapshot=json.loads(file.read()), storage_config=storage_config
)

@classmethod
def load_from_database(
self,
cls,
experiment_name: str,
storage_config: StorageConfig | None = None,
) -> Self:
"""
Restore an `AxClient` and its state from database by the given name.
Returns:
The restored `AxClient`.
"""
...
db_settings_base = WithDBSettingsBase(
db_settings=db_settings_from_storage_config(storage_config=storage_config)
if storage_config is not None
else None
)

maybe_experiment, maybe_generation_strategy = (
db_settings_base._load_experiment_and_generation_strategy(
experiment_name=experiment_name
)
)
if (experiment := maybe_experiment) is None:
raise ObjectNotFoundError(
f"Experiment {experiment_name} not found in database. Please check "
"its name is correct, check your StorageConfig is correct, or create "
"a new experiment."
)

client = cls(storage_config=storage_config)
client.set_experiment(experiment=experiment)
if maybe_generation_strategy is not None:
client.set_generation_strategy(
generation_strategy=maybe_generation_strategy
)

return client

# -------------------- Section 5: Private Methods -------------------------------
# -------------------- Section 5.1: Getters and defaults ------------------------
Expand Down
3 changes: 2 additions & 1 deletion ax/preview/api/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Sequence
from typing import Any, Callable, List, Mapping, Sequence

from ax.preview.api.types import TParameterValue
from ax.storage.registry_bundle import RegistryBundle
Expand Down Expand Up @@ -161,5 +161,6 @@ class OrchestrationConfig:

@dataclass
class StorageConfig:
creator: Callable[..., Any] | None = None # pyre-fixme[4]
url: str | None = None
registry_bundle: RegistryBundle | None = None
Loading

0 comments on commit 4c471e8

Please sign in to comment.