From 9f39543a3ba1058d247f5c908cce534cd49030f5 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Tue, 21 Nov 2023 15:06:40 -0800 Subject: [PATCH] Create RGS class (#2009) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2009 Differential Revision: D51135077 --- ax/service/utils/with_db_settings_base.py | 4 ++-- ax/storage/sqa_store/load.py | 16 ++++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index c4c55359b47..d903633e2da 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -52,9 +52,9 @@ from ax.storage.sqa_store.encoder import Encoder from ax.storage.sqa_store.load import ( _get_experiment_id, - _get_generation_strategy_id, _load_experiment, _load_generation_strategy_by_experiment_name, + get_generation_strategy_id, ) from ax.storage.sqa_store.save import ( _save_experiment, @@ -149,7 +149,7 @@ def _get_experiment_and_generation_strategy_db_id( ) if not exp_id: return None, None - gs_id = _get_generation_strategy_id( + gs_id = get_generation_strategy_id( experiment_name=experiment_name, decoder=self.db_settings.decoder ) return exp_id, gs_id diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index d6cd5b92469..60c3bd94b39 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -364,9 +364,7 @@ def _load_generation_strategy_by_experiment_name( 1) Get SQLAlchemy object from DB. 2) Convert to corresponding Ax object. """ - gs_id = _get_generation_strategy_id( - experiment_name=experiment_name, decoder=decoder - ) + gs_id = get_generation_strategy_id(experiment_name=experiment_name, decoder=decoder) if gs_id is None: raise ObjectNotFoundError( f"Experiment {experiment_name} does not have a generation strategy " @@ -419,16 +417,14 @@ def _load_generation_strategy_by_id( gs_id=gs_id, decoder=decoder ) else: - gs_sqa = _get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder) + gs_sqa = get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder) return decoder.generation_strategy_from_sqa( gs_sqa=gs_sqa, experiment=experiment, reduced_state=reduced_state ) -def _get_generation_strategy_id( - experiment_name: str, decoder: Decoder -) -> Optional[int]: +def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> Optional[int]: """Get DB ID of the generation strategy, associated with the experiment with the given name if its in DB, return None otherwise. """ @@ -448,7 +444,7 @@ def _get_generation_strategy_id( return sqa_gs_id[0] -def _get_generation_strategy_sqa( +def get_generation_strategy_sqa( gs_id: int, decoder: Decoder, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. @@ -482,7 +478,7 @@ def _get_generation_strategy_sqa_reduced_state( decoder.config.class_to_sqa_class[GeneratorRun], ) - gs_sqa = _get_generation_strategy_sqa( + gs_sqa = get_generation_strategy_sqa( gs_id=gs_id, decoder=decoder, query_options=[ @@ -518,7 +514,7 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space( gs_id: int, decoder: Decoder ) -> SQAGenerationStrategy: """Obtains most of the SQLAlchemy generation strategy object from DB.""" - return _get_generation_strategy_sqa( + return get_generation_strategy_sqa( gs_id=gs_id, decoder=decoder, query_options=[