Skip to content

Commit

Permalink
Create RGS class (#2009)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2009

Differential Revision: D51135077
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 21, 2023
1 parent 703245e commit 9f39543
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
4 changes: 2 additions & 2 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.load import (
_get_experiment_id,
_get_generation_strategy_id,
_load_experiment,
_load_generation_strategy_by_experiment_name,
get_generation_strategy_id,
)
from ax.storage.sqa_store.save import (
_save_experiment,
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_experiment_and_generation_strategy_db_id(
)
if not exp_id:
return None, None
gs_id = _get_generation_strategy_id(
gs_id = get_generation_strategy_id(
experiment_name=experiment_name, decoder=self.db_settings.decoder
)
return exp_id, gs_id
Expand Down
16 changes: 6 additions & 10 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ def _load_generation_strategy_by_experiment_name(
1) Get SQLAlchemy object from DB.
2) Convert to corresponding Ax object.
"""
gs_id = _get_generation_strategy_id(
experiment_name=experiment_name, decoder=decoder
)
gs_id = get_generation_strategy_id(experiment_name=experiment_name, decoder=decoder)
if gs_id is None:
raise ObjectNotFoundError(
f"Experiment {experiment_name} does not have a generation strategy "
Expand Down Expand Up @@ -419,16 +417,14 @@ def _load_generation_strategy_by_id(
gs_id=gs_id, decoder=decoder
)
else:
gs_sqa = _get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder)
gs_sqa = get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder)

return decoder.generation_strategy_from_sqa(
gs_sqa=gs_sqa, experiment=experiment, reduced_state=reduced_state
)


def _get_generation_strategy_id(
experiment_name: str, decoder: Decoder
) -> Optional[int]:
def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> Optional[int]:
"""Get DB ID of the generation strategy, associated with the experiment
with the given name if its in DB, return None otherwise.
"""
Expand All @@ -448,7 +444,7 @@ def _get_generation_strategy_id(
return sqa_gs_id[0]


def _get_generation_strategy_sqa(
def get_generation_strategy_sqa(
gs_id: int,
decoder: Decoder,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
Expand Down Expand Up @@ -482,7 +478,7 @@ def _get_generation_strategy_sqa_reduced_state(
decoder.config.class_to_sqa_class[GeneratorRun],
)

gs_sqa = _get_generation_strategy_sqa(
gs_sqa = get_generation_strategy_sqa(
gs_id=gs_id,
decoder=decoder,
query_options=[
Expand Down Expand Up @@ -518,7 +514,7 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space(
gs_id: int, decoder: Decoder
) -> SQAGenerationStrategy:
"""Obtains most of the SQLAlchemy generation strategy object from DB."""
return _get_generation_strategy_sqa(
return get_generation_strategy_sqa(
gs_id=gs_id,
decoder=decoder,
query_options=[
Expand Down

0 comments on commit 9f39543

Please sign in to comment.