diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index dd8c0405c37..8936aada6fd 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -53,9 +53,9 @@ from ax.storage.sqa_store.encoder import Encoder from ax.storage.sqa_store.load import ( _get_experiment_id, - _get_generation_strategy_id, _load_experiment, _load_generation_strategy_by_experiment_name, + get_generation_strategy_id, ) from ax.storage.sqa_store.save import ( _save_experiment, @@ -150,7 +150,7 @@ def _get_experiment_and_generation_strategy_db_id( ) if not exp_id: return None, None - gs_id = _get_generation_strategy_id( + gs_id = get_generation_strategy_id( experiment_name=experiment_name, decoder=self.db_settings.decoder ) return exp_id, gs_id diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 4aed9268c3a..6c3a838d8db 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -374,9 +374,7 @@ def _load_generation_strategy_by_experiment_name( 1) Get SQLAlchemy object from DB. 2) Convert to corresponding Ax object. """ - gs_id = _get_generation_strategy_id( - experiment_name=experiment_name, decoder=decoder - ) + gs_id = get_generation_strategy_id(experiment_name=experiment_name, decoder=decoder) if gs_id is None: raise ObjectNotFoundError( f"Experiment {experiment_name} does not have a generation strategy " @@ -429,16 +427,14 @@ def _load_generation_strategy_by_id( gs_id=gs_id, decoder=decoder ) else: - gs_sqa = _get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder) + gs_sqa = get_generation_strategy_sqa(gs_id=gs_id, decoder=decoder) return decoder.generation_strategy_from_sqa( gs_sqa=gs_sqa, experiment=experiment, reduced_state=reduced_state ) -def _get_generation_strategy_id( - experiment_name: str, decoder: Decoder -) -> Optional[int]: +def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> Optional[int]: """Get DB ID of the generation strategy, associated with the experiment with the given name if its in DB, return None otherwise. """ @@ -458,7 +454,7 @@ def _get_generation_strategy_id( return sqa_gs_id[0] -def _get_generation_strategy_sqa( +def get_generation_strategy_sqa( gs_id: int, decoder: Decoder, # pyre-fixme[2]: Parameter annotation cannot contain `Any`. @@ -492,7 +488,7 @@ def _get_generation_strategy_sqa_reduced_state( decoder.config.class_to_sqa_class[GeneratorRun], ) - gs_sqa = _get_generation_strategy_sqa( + gs_sqa = get_generation_strategy_sqa( gs_id=gs_id, decoder=decoder, query_options=[ @@ -528,7 +524,7 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space( gs_id: int, decoder: Decoder ) -> SQAGenerationStrategy: """Obtains most of the SQLAlchemy generation strategy object from DB.""" - return _get_generation_strategy_sqa( + return get_generation_strategy_sqa( gs_id=gs_id, decoder=decoder, query_options=[