Skip to content

Commit

Permalink
Finalize fix of not needing to store transition criterion on steps (f…
Browse files Browse the repository at this point in the history
…acebook#1985)

Summary:

This diff:
Finalizes the fix to the storage by removing the need to unset transiton criteria and doesn't store transition criterion anymore. It can do so because in the decoder we reconstruct the generation step, which automatically fills in the relevant node fields during its init method.

Reviewed By: lena-kashtelyan

Differential Revision: D50752054
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Nov 14, 2023
1 parent e51d48b commit 6a3bc3e
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 14 deletions.
2 changes: 0 additions & 2 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,6 @@ def _unset_non_persistent_state_fields(self) -> None:
self._model = None
for s in self._steps:
s._model_spec_to_gen_from = None
# TODO: @mgarrard remove once re-enabled criterion storage
s._transition_criteria = []

def __repr__(self) -> str:
"""String representation of this generation strategy."""
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ def test_save_and_load_generation_strategy(self) -> None:
)
second_client = AxClient(db_settings=db_settings)
second_client.load_experiment_from_database("unique_test_experiment")
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(second_client.generation_strategy, generation_strategy)

@patch(
Expand Down
1 change: 0 additions & 1 deletion ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,6 @@ def test_sqa_storage(self) -> None:
# Check that experiment and GS were saved.
exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name)
self.assertEqual(exp, experiment)
self.two_sobol_steps_GS._unset_non_persistent_state_fields()
self.assertEqual(gs, self.two_sobol_steps_GS)
scheduler.run_all_trials()
# Check that experiment and GS were saved and test reloading with reduced state.
Expand Down
5 changes: 0 additions & 5 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,11 +732,6 @@ def generation_step_from_json(
if "should_deduplicate" in generation_step_json
else False,
)
generation_step._transition_criteria = transition_criteria_from_json(
generation_step_json.pop("transition_criteria")
if "transition_criteria" in generation_step_json.keys()
else None
)
return generation_step


Expand Down
12 changes: 8 additions & 4 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.exceptions.core import AxStorageWarning
Expand Down Expand Up @@ -328,8 +327,14 @@ def test_EncodeDecode(self) -> None:
converted_object = converted_object.state_dict()
if isinstance(original_object, GenerationStrategy):
original_object._unset_non_persistent_state_fields()
if isinstance(original_object, BenchmarkMethod):
original_object.generation_strategy._unset_non_persistent_state_fields()
# for the test, completion criterion are set post init
# and therefore do not become transition critirion, unset
# for this specific test only
if "with_completion_criteria" in fake_func.keywords:
for step in original_object._steps:
step._transition_criteria = None
for step in converted_object._steps:
step._transition_criteria = None
try:
self.assertEqual(
original_object,
Expand Down Expand Up @@ -402,7 +407,6 @@ def test_DecodeGenerationStrategy(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertGreater(len(new_generation_strategy._steps), 0)
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,6 @@ def test_EncodeDecodeGenerationStrategy(self) -> None:
# pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`.
gs_id=generation_strategy._db_id
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsNone(generation_strategy._experiment)

Expand Down

0 comments on commit 6a3bc3e

Please sign in to comment.