Skip to content

Commit

Permalink
Store state of stateful models, to add to kwargs when reinstantiating
Browse files Browse the repository at this point in the history
Summary: Some of the Ax models (currently only `SobolGenerator`, but in the future Hyperband & others) have state, so it's not enough for those to record the kwargs, with which they've been instantiated. This diff introduces the logic used to store the state of those models, such that they can be reinstantiated from a generator run they produced.

Reviewed By: stevemandala

Differential Revision: D16993233

fbshipit-source-id: 5e96ec993535399e63446016c11cb2892b204b62
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Aug 27, 2019
1 parent 4624478 commit 6298b14
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 8 deletions.
20 changes: 20 additions & 0 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ax.core.types import TConfig, TModelCov, TModelMean, TModelPredict
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none


logger = get_logger("ModelBridge")
Expand Down Expand Up @@ -565,6 +566,7 @@ def gen(
model_kwargs=self._model_kwargs,
bridge_kwargs=self._bridge_kwargs,
)
self._save_model_state_if_possible()
self.fit_time_since_gen = 0.0
return gr

Expand Down Expand Up @@ -663,6 +665,24 @@ def _set_kwargs_to_save(
self._model_kwargs = model_kwargs
self._bridge_kwargs = bridge_kwargs

def _save_model_state_if_possible(self) -> None:
"""Stores state of stateful models together with the model and bridge
kwargs.
Currently the only stateful model is the `SobolGenerator`, for which it
is enough to update the stored model kwargs with that state. However, as
the number of stateful models increases, this function may require
additional logic.
"""
if (
hasattr(self, "model")
and hasattr(self, "_model_kwargs")
and self._model_kwargs is not None
):
# `ModelBridge` still has no attr. `model` after `hasattr` call,
# pyre-fixme[16]: which should've ensured its presence to typechecker.
not_none(self._model_kwargs).update(self.model._get_state())


def unwrap_observation_data(observation_data: List[ObservationData]) -> TModelPredict:
"""Converts observation data to the format for model prediction outputs.
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_sobol_GPEI_strategy(self, mock_GPEI_gen, mock_GPEI_update, mock_GPEI_in
{
"seed": None,
"deduplicate": False,
"init_position": 0,
"init_position": i + 1,
"scramble": True,
},
)
Expand Down
31 changes: 31 additions & 0 deletions ax/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, Dict


class Model:
"""Base class for an Ax model.
Note: the core methods each model has: `fit`, `predict`, `gen`,
`cross_validate`, and `best_point` are not present in this base class,
because the signatures for those methods vary based on the type of the model.
This class only contains the methods that all models have in common and for
which they all share the signature.
"""

def _get_state(self) -> Dict[str, Any]:
"""Obtain the state of this model, in order to be able to serialize it
and restore it from the serialized version.
While most models in Ax aren't stateful, some models, like `SobolGenerator`,
are. For Sobol, the value of the `init_position` changes throughout the
generation process as more arms are generated, and restoring the Sobol
generator with all the same settings as it was initialized with, will not
result in the same model, because the `init_position` setting changed
throughout optimization. Stateful settings like that are returned from
this method, so that a model can be reinstantiated and 'pick up where it
left off' –– more arms can be generated as if the model just continued
generation and was never interrupted and serialized.
"""
return {}
3 changes: 2 additions & 1 deletion ax/models/discrete_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import numpy as np
from ax.core.types import TConfig, TParamValue, TParamValueList
from ax.models.base import Model


class DiscreteModel:
class DiscreteModel(Model):
"""This class specifies the interface for a model based on discrete parameters.
These methods should be implemented to have access to all of the features
Expand Down
3 changes: 2 additions & 1 deletion ax/models/numpy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import numpy as np
from ax.core.types import TConfig
from ax.models.base import Model


class NumpyModel:
class NumpyModel(Model):
"""This class specifies the interface for a numpy-based model.
These methods should be implemented to have access to all of the features
Expand Down
3 changes: 2 additions & 1 deletion ax/models/random/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
from ax.core.types import TConfig
from ax.models.base import Model
from ax.models.model_utils import (
add_fixed_features,
rejection_sample,
Expand All @@ -13,7 +14,7 @@
)


class RandomModel:
class RandomModel(Model):
"""This class specifies the basic skeleton for a random model.
As random generators do not make use of models, they do not implement
Expand Down
11 changes: 8 additions & 3 deletions ax/models/random/sobol.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
from ax.core.types import TConfig
from ax.models.base import Model
from ax.models.model_utils import tunable_feature_indices
from ax.models.random.base import RandomModel
from ax.utils.common.docutils import copy_doc
from ax.utils.stats.sobol import SobolEngine # pyre-ignore: Not handling .pyx properly


Expand Down Expand Up @@ -61,8 +63,7 @@ def init_engine(self, n_tunable_features: int) -> SobolEngine:
return self._engine

@property
# pyre-fixme[11]: Type `SobolEngine` is not defined.
def engine(self) -> SobolEngine:
def engine(self) -> Optional[SobolEngine]: # pyre-fixme[31]: not valid type.
"""Return a singleton SobolEngine."""
return self._engine

Expand Down Expand Up @@ -112,6 +113,10 @@ def gen(
self.init_position = self.engine.num_generated
return (points, weights)

@copy_doc(Model._get_state)
def _get_state(self) -> Dict[str, Any]:
return {"init_position": self.init_position}

def _gen_samples(self, n: int, tunable_d: int) -> np.ndarray:
"""Generate n samples.
Expand Down
4 changes: 4 additions & 0 deletions ax/models/tests/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class DiscreteModelTest(TestCase):
def setUp(self):
pass

def test_discrete_model_get_state(self):
discrete_model = DiscreteModel()
self.assertEqual(discrete_model._get_state(), {})

def testDiscreteModelFit(self):
discrete_model = DiscreteModel()
discrete_model.fit(
Expand Down
3 changes: 2 additions & 1 deletion ax/models/torch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import torch
from ax.core.types import TConfig
from ax.models.base import Model
from torch import Tensor


class TorchModel:
class TorchModel(Model):
"""This class specifies the interface for a torch-based model.
These methods should be implemented to have access to all of the features
Expand Down
30 changes: 30 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,33 @@ def test_fixed_random_seed_reproducibility(self):
ax.complete_trial(idx, branin(params.get("x1"), params.get("x2")))
trial_parameters_2 = [t.arm.parameters for t in ax.experiment.trials.values()]
self.assertEqual(trial_parameters_1, trial_parameters_2)

def test_init_position_saved(self):
ax = AxClient(random_seed=239)
ax.create_experiment(
parameters=[
{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
],
name="sobol_init_position_test",
)
for _ in range(4):
# For each generated trial, snapshot the client before generating it,
# then recreate client, regenerate the trial and compare the trial
# generated before and after snapshotting. If the state of Sobol is
# recorded correctly, the newly generated trial will be the same as
# the one generated before the snapshotting.
serialized = ax.to_json_snapshot()
params, idx = ax.get_next_trial()
ax = AxClient.from_json_snapshot(serialized)
with self.subTest(ax=ax, params=params, idx=idx):
new_params, new_idx = ax.get_next_trial()
self.assertEqual(params, new_params)
self.assertEqual(idx, new_idx)
self.assertEqual(
ax.experiment.trials[idx]._generator_run._model_kwargs[
"init_position"
],
idx + 1,
)
ax.complete_trial(idx, branin(params.get("x1"), params.get("x2")))

0 comments on commit 6298b14

Please sign in to comment.