Skip to content

Commit

Permalink
Implements MapKeyToFloat, a subclass of the MetadataToFloat Transform…
Browse files Browse the repository at this point in the history
… that provides sensible defaults for MapData (facebook#3155)

Summary:
Pull Request resolved: facebook#3155

This adds a specialization of the `MetadataToFloat` Transform, `MapKeyToFloat`, that provides sensible default settings to allow for intercepting map metric data appearing in the ObservationFeatures' metadata.

Additionally, for the purposes of specifying `fixed_features` down the line, when `_transform_observation_feature` is given an empty `ObservationFeatures` (more specifically, an `ObservationFeatures` with an empty `parameters` dict), it will populate it with the *upper bound* associated with each metadata key.

Differential Revision: D66945078
  • Loading branch information
Louis Tiao authored and facebook-github-bot committed Dec 17, 2024
1 parent fa7843f commit a9c2583
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 0 deletions.
54 changes: 54 additions & 0 deletions ax/modelbridge/transforms/map_key_to_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Optional, TYPE_CHECKING

from ax.core.map_metric import MapMetric
from ax.core.observation import Observation, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.metadata_to_range import MetadataToFloat
from ax.models.types import TConfig
from pyre_extensions import assert_is_instance

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


class MapKeyToFloat(MetadataToFloat):
DEFAULT_LOG_SCALE: bool = True
DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key

def __init__(
self,
search_space: SearchSpace | None = None,
observations: list[Observation] | None = None,
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
config: TConfig | None = None,
) -> None:
config = config or {}
self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
config.setdefault("parameters", {}), dict
)
# TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?)
self.parameters.setdefault(self.DEFAULT_MAP_KEY, {})
super().__init__(
search_space=search_space,
observations=observations,
modelbridge=modelbridge,
config=config,
)

def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
if not obsf.parameters:
for p in self._parameter_list:
# TODO[tiao]: can we use be p.target_value?
# (not its original intended use but could be advantageous)
obsf.parameters[p.name] = p.upper
return
super()._transform_observation_feature(obsf)
174 changes: 174 additions & 0 deletions ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy
from typing import Iterator

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat
from ax.utils.common.testutils import TestCase
from pyre_extensions import assert_is_instance


WIDTHS = [2.0, 4.0, 8.0]
HEIGHTS = [4.0, 2.0, 8.0]
STEPS_ENDS = [1, 5, 3]


def _enumerate() -> Iterator[tuple[int, float, float, float]]:
yield from (
(trial_index, width, height, float(i + 1))
for trial_index, (width, height, steps_end) in enumerate(
zip(WIDTHS, HEIGHTS, STEPS_ENDS)
)
for i in range(steps_end)
)


class MapKeyToFloatTransformTest(TestCase):
def setUp(self) -> None:
super().setUp()

self.search_space = SearchSpace(
parameters=[
RangeParameter(
name="width",
parameter_type=ParameterType.FLOAT,
lower=1,
upper=20,
),
RangeParameter(
name="height",
parameter_type=ParameterType.FLOAT,
lower=1,
upper=20,
),
]
)

self.observations = []
for trial_index, width, height, steps in _enumerate():
obs_feat = ObservationFeatures(
trial_index=trial_index,
parameters={"width": width, "height": height},
metadata={
"foo": 42,
MapKeyToFloat.DEFAULT_MAP_KEY: steps,
},
)
obs_data = ObservationData(
metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
)
self.observations.append(Observation(features=obs_feat, data=obs_data))

# does not require explicitly specifying `config`
self.t = MapKeyToFloat(
observations=self.observations,
)

def test_Init(self) -> None:
self.assertEqual(len(self.t._parameter_list), 1)

p = self.t._parameter_list[0]

self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 1.0)
self.assertEqual(p.upper, 5.0)
self.assertTrue(p.log_scale)

with self.subTest("infer parameter type"):
observations = []
for trial_index, width, height, steps in _enumerate():
obs_feat = ObservationFeatures(
trial_index=trial_index,
parameters={"width": width, "height": height},
metadata={
"foo": 42,
"bar": int(steps),
},
)
obs_data = ObservationData(
metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
)
observations.append(Observation(features=obs_feat, data=obs_data))

# test that one is able to override default config
with self.subTest(msg="override default config"):
t = MapKeyToFloat(
observations=self.observations,
config={
"parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}}
},
)
self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}})

self.assertEqual(len(t._parameter_list), 1)

p = t._parameter_list[0]

self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 1.0)
self.assertEqual(p.upper, 5.0)
self.assertFalse(p.log_scale)

def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
ss2 = self.t.transform_search_space(ss2)

self.assertSetEqual(
set(ss2.parameters),
{"height", "width", MapKeyToFloat.DEFAULT_MAP_KEY},
)

p = assert_is_instance(
ss2.parameters[MapKeyToFloat.DEFAULT_MAP_KEY], RangeParameter
)

self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY)
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 1.0)
self.assertEqual(p.upper, 5.0)
self.assertTrue(p.log_scale)

def test_TransformObservationFeatures(self) -> None:
observation_features = [obs.features for obs in self.observations]
obs_ft2 = deepcopy(observation_features)
obs_ft2 = self.t.transform_observation_features(obs_ft2)

self.assertEqual(
obs_ft2,
[
ObservationFeatures(
trial_index=trial_index,
parameters={
"width": width,
"height": height,
MapKeyToFloat.DEFAULT_MAP_KEY: steps,
},
metadata={"foo": 42},
)
for trial_index, width, height, steps in _enumerate()
],
)
obs_ft2 = self.t.untransform_observation_features(obs_ft2)
self.assertEqual(obs_ft2, observation_features)

def test_TransformObservationFeaturesWithEmptyParameters(self) -> None:
obsf = ObservationFeatures(parameters={})
self.t.transform_observation_features([obsf])

p = self.t._parameter_list[0]
self.assertEqual(
obsf,
ObservationFeatures(parameters={MapKeyToFloat.DEFAULT_MAP_KEY: p.upper}),
)
9 changes: 9 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ Transforms
:undoc-members:
:show-inheritance:


`ax.modelbridge.transforms.map\_key\_to\_float`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.modelbridge.transforms.map_key_to_float
:members:
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.rounding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit a9c2583

Please sign in to comment.