forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce new Transform that adds metadata as parameters in an Observ…
…ationFeature (facebook#3023) Summary: Pull Request resolved: facebook#3023 **Context:** The values corresponding to map keys are propagated as part of the ObservationFeatures' `metadata` dict field. We require a way to place it in the `parameters` dict field so that it can be used later on. This generalized transform is able to take user-specified entries from an `ObservationFeatures`'s `metadata` field and place it within its `parameters` field, and update the search space accordingly to reflect this. This implements a new transform, `MetadataToFloat`, that extracts specified fields from each `ObservationFeature` instance's metadata and incorporates them as parameters. Furthermore, it updates the search space to include the specified field as a `RangeParameter` with bounds determined by observations provided during initialization. This process involves analyzing the metadata of each observation feature and identifying relevant fields that need to be included in the search space. The bounds for these fields are then determined based on the observations provided during initialization. Differential Revision: D65430943
- Loading branch information
1 parent
c9303da
commit fa7843f
Showing
3 changed files
with
332 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from __future__ import annotations | ||
|
||
from logging import Logger | ||
from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING | ||
|
||
from ax.core import ParameterType | ||
|
||
from ax.core.observation import Observation, ObservationFeatures | ||
from ax.core.parameter import RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.exceptions.core import DataRequiredError | ||
from ax.modelbridge.transforms.base import Transform | ||
from ax.models.types import TConfig | ||
from ax.utils.common.logger import get_logger | ||
from pyre_extensions import assert_is_instance, none_throws | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
logger: Logger = get_logger(__name__) | ||
|
||
|
||
class MetadataToFloat(Transform): | ||
""" | ||
This transform converts metadata from observation features into range (float) | ||
parameters for a search space. | ||
It allows the user to specify the `config` with `parameters` as the key, where | ||
each entry maps a metadata key to a dictionary of keyword arguments for the | ||
corresponding RangeParameter constructor. | ||
Transform is done in-place. | ||
""" | ||
|
||
DEFAULT_LOG_SCALE: bool = False | ||
DEFAULT_LOGIT_SCALE: bool = False | ||
DEFAULT_IS_FIDELITY: bool = False | ||
ENFORCE_BOUNDS: bool = False | ||
|
||
def __init__( | ||
self, | ||
search_space: SearchSpace | None = None, | ||
observations: list[Observation] | None = None, | ||
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, | ||
config: TConfig | None = None, | ||
) -> None: | ||
if observations is None or not observations: | ||
raise DataRequiredError( | ||
"`MetadataToRange` transform requires non-empty data." | ||
) | ||
config = config or {} | ||
self.parameters: dict[str, dict[str, Any]] = assert_is_instance( | ||
config.get("parameters", {}), dict | ||
) | ||
|
||
self._parameter_list: list[RangeParameter] = [] | ||
for name in self.parameters: | ||
lb = ub = None # de facto bounds | ||
for obs in observations: | ||
obsf_metadata = none_throws(obs.features.metadata) | ||
|
||
val = float(assert_is_instance(obsf_metadata[name], SupportsFloat)) | ||
|
||
lb = min(val, lb) if lb is not None else val | ||
ub = max(val, ub) if ub is not None else val | ||
|
||
lower: float = self.parameters[name].get("lower", lb) | ||
upper: float = self.parameters[name].get("upper", ub) | ||
|
||
log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) | ||
logit_scale = self.parameters[name].get( | ||
"logit_scale", self.DEFAULT_LOGIT_SCALE | ||
) | ||
digits = self.parameters[name].get("digits") | ||
is_fidelity = self.parameters[name].get( | ||
"is_fidelity", self.DEFAULT_IS_FIDELITY | ||
) | ||
|
||
target_value = self.parameters[name].get("target_value") | ||
|
||
parameter = RangeParameter( | ||
name=name, | ||
parameter_type=ParameterType.FLOAT, | ||
lower=lower, | ||
upper=upper, | ||
log_scale=log_scale, | ||
logit_scale=logit_scale, | ||
digits=digits, | ||
is_fidelity=is_fidelity, | ||
target_value=target_value, | ||
) | ||
self._parameter_list.append(parameter) | ||
|
||
def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: | ||
for parameter in self._parameter_list: | ||
search_space.add_parameter(parameter.clone()) | ||
return search_space | ||
|
||
def transform_observation_features( | ||
self, observation_features: list[ObservationFeatures] | ||
) -> list[ObservationFeatures]: | ||
for obsf in observation_features: | ||
self._transform_observation_feature(obsf) | ||
return observation_features | ||
|
||
def untransform_observation_features( | ||
self, observation_features: list[ObservationFeatures] | ||
) -> list[ObservationFeatures]: | ||
for obsf in observation_features: | ||
obsf.metadata = obsf.metadata or {} | ||
_transfer( | ||
src=obsf.parameters, | ||
dst=obsf.metadata, | ||
keys=self.parameters.keys(), | ||
) | ||
return observation_features | ||
|
||
def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: | ||
_transfer( | ||
src=none_throws(obsf.metadata), | ||
dst=obsf.parameters, | ||
keys=self.parameters.keys(), | ||
) | ||
|
||
|
||
def _transfer( | ||
src: dict[str, Any], | ||
dst: dict[str, Any], | ||
keys: Iterable[str], | ||
) -> None: | ||
"""Transfer items in-place from one dictionary to another.""" | ||
for key in keys: | ||
dst[key] = src.pop(key) |
180 changes: 180 additions & 0 deletions
180
ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from copy import deepcopy | ||
from typing import Iterator | ||
|
||
import numpy as np | ||
from ax.core.observation import Observation, ObservationData, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.exceptions.core import DataRequiredError | ||
from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat | ||
from ax.utils.common.testutils import TestCase | ||
from pyre_extensions import assert_is_instance | ||
|
||
|
||
WIDTHS = [2.0, 4.0, 8.0] | ||
HEIGHTS = [4.0, 2.0, 8.0] | ||
STEPS_ENDS = [1, 5, 3] | ||
|
||
|
||
def _enumerate() -> Iterator[tuple[int, float, float, float]]: | ||
yield from ( | ||
(trial_index, width, height, float(i + 1)) | ||
for trial_index, (width, height, steps_end) in enumerate( | ||
zip(WIDTHS, HEIGHTS, STEPS_ENDS) | ||
) | ||
for i in range(steps_end) | ||
) | ||
|
||
|
||
class MetadataToFloatTransformTest(TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
|
||
self.search_space = SearchSpace( | ||
parameters=[ | ||
RangeParameter( | ||
name="width", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
RangeParameter( | ||
name="height", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
] | ||
) | ||
|
||
self.observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
"bar": 3.0 * steps, | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
self.observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
self.t = MetadataToFloat( | ||
observations=self.observations, | ||
config={ | ||
"parameters": {"bar": {"log_scale": True}}, | ||
}, | ||
) | ||
|
||
def test_Init(self) -> None: | ||
self.assertEqual(len(self.t._parameter_list), 1) | ||
|
||
p = self.t._parameter_list[0] | ||
|
||
# check that the parameter options are specified in a sensible manner | ||
# by default if the user does not specify them explicitly | ||
self.assertEqual(p.name, "bar") | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 3.0) | ||
self.assertEqual(p.upper, 15.0) | ||
self.assertTrue(p.log_scale) | ||
self.assertFalse(p.logit_scale) | ||
self.assertIsNone(p.digits) | ||
self.assertFalse(p.is_fidelity) | ||
self.assertIsNone(p.target_value) | ||
|
||
with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): | ||
MetadataToFloat(search_space=None, observations=None) | ||
with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): | ||
MetadataToFloat(search_space=None, observations=[]) | ||
|
||
with self.subTest("infer parameter type"): | ||
observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
"bar": int(steps), | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
t = MetadataToFloat( | ||
observations=observations, | ||
config={ | ||
"parameters": {"bar": {}}, | ||
}, | ||
) | ||
self.assertEqual(len(t._parameter_list), 1) | ||
|
||
p = t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, "bar") | ||
self.assertEqual(p.parameter_type, ParameterType.INT) | ||
self.assertEqual(p.lower, 1) | ||
self.assertEqual(p.upper, 5) | ||
self.assertFalse(p.log_scale) | ||
self.assertFalse(p.logit_scale) | ||
self.assertIsNone(p.digits) | ||
self.assertFalse(p.is_fidelity) | ||
self.assertIsNone(p.target_value) | ||
|
||
def test_TransformSearchSpace(self) -> None: | ||
ss2 = deepcopy(self.search_space) | ||
ss2 = self.t.transform_search_space(ss2) | ||
|
||
self.assertSetEqual( | ||
set(ss2.parameters.keys()), | ||
{"height", "width", "bar"}, | ||
) | ||
|
||
p = assert_is_instance(ss2.parameters["bar"], RangeParameter) | ||
|
||
self.assertEqual(p.name, "bar") | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 3.0) | ||
self.assertEqual(p.upper, 15.0) | ||
self.assertTrue(p.log_scale) | ||
self.assertFalse(p.logit_scale) | ||
self.assertIsNone(p.digits) | ||
self.assertFalse(p.is_fidelity) | ||
self.assertIsNone(p.target_value) | ||
|
||
def test_TransformObservationFeatures(self) -> None: | ||
observation_features = [obs.features for obs in self.observations] | ||
obs_ft2 = deepcopy(observation_features) | ||
obs_ft2 = self.t.transform_observation_features(obs_ft2) | ||
|
||
self.assertEqual( | ||
obs_ft2, | ||
[ | ||
ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={ | ||
"width": width, | ||
"height": height, | ||
"bar": 3.0 * steps, | ||
}, | ||
metadata={"foo": 42}, | ||
) | ||
for trial_index, width, height, steps in _enumerate() | ||
], | ||
) | ||
obs_ft2 = self.t.untransform_observation_features(obs_ft2) | ||
self.assertEqual(obs_ft2, observation_features) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters