diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index b8ade5f2199..defb1ffc582 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -330,8 +330,6 @@ def _unset_non_persistent_state_fields(self) -> None: self._model = None for s in self._steps: s._model_spec_to_gen_from = None - # TODO: @mgarrard remove once re-enabled criterion storage - s._transition_criteria = [] def __repr__(self) -> str: """String representation of this generation strategy.""" diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index aa292139a1f..10440f2102e 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -614,7 +614,6 @@ def test_save_and_load_generation_strategy(self) -> None: ) second_client = AxClient(db_settings=db_settings) second_client.load_experiment_from_database("unique_test_experiment") - generation_strategy._unset_non_persistent_state_fields() self.assertEqual(second_client.generation_strategy, generation_strategy) @patch( diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 0657668d62e..8533294f986 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -777,7 +777,6 @@ def test_sqa_storage(self) -> None: # Check that experiment and GS were saved. exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name) self.assertEqual(exp, experiment) - self.two_sobol_steps_GS._unset_non_persistent_state_fields() self.assertEqual(gs, self.two_sobol_steps_GS) scheduler.run_all_trials() # Check that experiment and GS were saved and test reloading with reduced state. diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index dee34736a01..898248e96e6 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -732,11 +732,6 @@ def generation_step_from_json( if "should_deduplicate" in generation_step_json else False, ) - generation_step._transition_criteria = transition_criteria_from_json( - generation_step_json.pop("transition_criteria") - if "transition_criteria" in generation_step_json.keys() - else None - ) return generation_step diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index b6c5a78f950..a90d708a26f 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -10,7 +10,6 @@ import numpy as np import torch -from ax.benchmark.benchmark_method import BenchmarkMethod from ax.core.metric import Metric from ax.core.runner import Runner from ax.exceptions.core import AxStorageWarning @@ -328,8 +327,14 @@ def test_EncodeDecode(self) -> None: converted_object = converted_object.state_dict() if isinstance(original_object, GenerationStrategy): original_object._unset_non_persistent_state_fields() - if isinstance(original_object, BenchmarkMethod): - original_object.generation_strategy._unset_non_persistent_state_fields() + # for the test, completion criterion are set post init + # and therefore do not become transition critirion, unset + # for this specific test only + if "with_completion_criteria" in fake_func.keywords: + for step in original_object._steps: + step._transition_criteria = None + for step in converted_object._steps: + step._transition_criteria = None try: self.assertEqual( original_object, @@ -402,7 +407,6 @@ def test_DecodeGenerationStrategy(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) - generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertGreater(len(new_generation_strategy._steps), 0) self.assertIsInstance(new_generation_strategy._steps[0].model, Models) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 17b4aa436ca..61e81e59ea0 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1226,7 +1226,6 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. gs_id=generation_strategy._db_id ) - generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertIsNone(generation_strategy._experiment)