forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implements MapKeyToFloat, a subclass of the MetadataToFloat Transform…
… that provides sensible defaults for MapData (facebook#3155) Summary: Pull Request resolved: facebook#3155 This adds a specialization of the `MetadataToFloat` Transform, `MapKeyToFloat`, that provides sensible default settings to allow for intercepting map metric data appearing in the ObservationFeatures' metadata. Additionally, for the purposes of specifying `fixed_features` down the line, when `_transform_observation_feature` is given an empty `ObservationFeatures` (more specifically, an `ObservationFeatures` with an empty `parameters` dict), it will populate it with the *upper bound* associated with each metadata key. Differential Revision: D66945078
- Loading branch information
1 parent
fa7843f
commit a9c2583
Showing
3 changed files
with
237 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from typing import Any, Optional, TYPE_CHECKING | ||
|
||
from ax.core.map_metric import MapMetric | ||
from ax.core.observation import Observation, ObservationFeatures | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.metadata_to_range import MetadataToFloat | ||
from ax.models.types import TConfig | ||
from pyre_extensions import assert_is_instance | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
class MapKeyToFloat(MetadataToFloat): | ||
DEFAULT_LOG_SCALE: bool = True | ||
DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key | ||
|
||
def __init__( | ||
self, | ||
search_space: SearchSpace | None = None, | ||
observations: list[Observation] | None = None, | ||
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, | ||
config: TConfig | None = None, | ||
) -> None: | ||
config = config or {} | ||
self.parameters: dict[str, dict[str, Any]] = assert_is_instance( | ||
config.setdefault("parameters", {}), dict | ||
) | ||
# TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?) | ||
self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) | ||
super().__init__( | ||
search_space=search_space, | ||
observations=observations, | ||
modelbridge=modelbridge, | ||
config=config, | ||
) | ||
|
||
def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: | ||
if not obsf.parameters: | ||
for p in self._parameter_list: | ||
# TODO[tiao]: can we use be p.target_value? | ||
# (not its original intended use but could be advantageous) | ||
obsf.parameters[p.name] = p.upper | ||
return | ||
super()._transform_observation_feature(obsf) |
174 changes: 174 additions & 0 deletions
174
ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from copy import deepcopy | ||
from typing import Iterator | ||
|
||
import numpy as np | ||
from ax.core.observation import Observation, ObservationData, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat | ||
from ax.utils.common.testutils import TestCase | ||
from pyre_extensions import assert_is_instance | ||
|
||
|
||
WIDTHS = [2.0, 4.0, 8.0] | ||
HEIGHTS = [4.0, 2.0, 8.0] | ||
STEPS_ENDS = [1, 5, 3] | ||
|
||
|
||
def _enumerate() -> Iterator[tuple[int, float, float, float]]: | ||
yield from ( | ||
(trial_index, width, height, float(i + 1)) | ||
for trial_index, (width, height, steps_end) in enumerate( | ||
zip(WIDTHS, HEIGHTS, STEPS_ENDS) | ||
) | ||
for i in range(steps_end) | ||
) | ||
|
||
|
||
class MapKeyToFloatTransformTest(TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
|
||
self.search_space = SearchSpace( | ||
parameters=[ | ||
RangeParameter( | ||
name="width", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
RangeParameter( | ||
name="height", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
] | ||
) | ||
|
||
self.observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
MapKeyToFloat.DEFAULT_MAP_KEY: steps, | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
self.observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
# does not require explicitly specifying `config` | ||
self.t = MapKeyToFloat( | ||
observations=self.observations, | ||
) | ||
|
||
def test_Init(self) -> None: | ||
self.assertEqual(len(self.t._parameter_list), 1) | ||
|
||
p = self.t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertTrue(p.log_scale) | ||
|
||
with self.subTest("infer parameter type"): | ||
observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
"bar": int(steps), | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
# test that one is able to override default config | ||
with self.subTest(msg="override default config"): | ||
t = MapKeyToFloat( | ||
observations=self.observations, | ||
config={ | ||
"parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}} | ||
}, | ||
) | ||
self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) | ||
|
||
self.assertEqual(len(t._parameter_list), 1) | ||
|
||
p = t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertFalse(p.log_scale) | ||
|
||
def test_TransformSearchSpace(self) -> None: | ||
ss2 = deepcopy(self.search_space) | ||
ss2 = self.t.transform_search_space(ss2) | ||
|
||
self.assertSetEqual( | ||
set(ss2.parameters), | ||
{"height", "width", MapKeyToFloat.DEFAULT_MAP_KEY}, | ||
) | ||
|
||
p = assert_is_instance( | ||
ss2.parameters[MapKeyToFloat.DEFAULT_MAP_KEY], RangeParameter | ||
) | ||
|
||
self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertTrue(p.log_scale) | ||
|
||
def test_TransformObservationFeatures(self) -> None: | ||
observation_features = [obs.features for obs in self.observations] | ||
obs_ft2 = deepcopy(observation_features) | ||
obs_ft2 = self.t.transform_observation_features(obs_ft2) | ||
|
||
self.assertEqual( | ||
obs_ft2, | ||
[ | ||
ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={ | ||
"width": width, | ||
"height": height, | ||
MapKeyToFloat.DEFAULT_MAP_KEY: steps, | ||
}, | ||
metadata={"foo": 42}, | ||
) | ||
for trial_index, width, height, steps in _enumerate() | ||
], | ||
) | ||
obs_ft2 = self.t.untransform_observation_features(obs_ft2) | ||
self.assertEqual(obs_ft2, observation_features) | ||
|
||
def test_TransformObservationFeaturesWithEmptyParameters(self) -> None: | ||
obsf = ObservationFeatures(parameters={}) | ||
self.t.transform_observation_features([obsf]) | ||
|
||
p = self.t._parameter_list[0] | ||
self.assertEqual( | ||
obsf, | ||
ObservationFeatures(parameters={MapKeyToFloat.DEFAULT_MAP_KEY: p.upper}), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters