diff --git a/ax/modelbridge/transforms/map_key_to_float.py b/ax/modelbridge/transforms/map_key_to_float.py new file mode 100644 index 00000000000..1ec645aff51 --- /dev/null +++ b/ax/modelbridge/transforms/map_key_to_float.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Optional, TYPE_CHECKING + +from ax.core.map_metric import MapMetric +from ax.core.observation import Observation, ObservationFeatures +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.metadata_to_range import MetadataToFloat +from ax.models.types import TConfig +from pyre_extensions import assert_is_instance + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +class MapKeyToFloat(MetadataToFloat): + DEFAULT_LOG_SCALE: bool = True + DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + config = config or {} + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( + config.setdefault("parameters", {}), dict + ) + # TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?) + self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) + super().__init__( + search_space=search_space, + observations=observations, + modelbridge=modelbridge, + config=config, + ) + + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: + if not obsf.parameters: + for p in self._parameter_list: + # TODO[tiao]: can we use be p.target_value? + # (not its original intended use but could be advantageous) + obsf.parameters[p.name] = p.upper + return + super()._transform_observation_feature(obsf) diff --git a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py new file mode 100644 index 00000000000..c1ddb8c1d34 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from copy import deepcopy +from typing import Iterator + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance + + +WIDTHS = [2.0, 4.0, 8.0] +HEIGHTS = [4.0, 2.0, 8.0] +STEPS_ENDS = [1, 5, 3] + + +def _enumerate() -> Iterator[tuple[int, float, float, float]]: + yield from ( + (trial_index, width, height, float(i + 1)) + for trial_index, (width, height, steps_end) in enumerate( + zip(WIDTHS, HEIGHTS, STEPS_ENDS) + ) + for i in range(steps_end) + ) + + +class MapKeyToFloatTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + name="width", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + RangeParameter( + name="height", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + ] + ) + + self.observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + MapKeyToFloat.DEFAULT_MAP_KEY: steps, + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + self.observations.append(Observation(features=obs_feat, data=obs_data)) + + # does not require explicitly specifying `config` + self.t = MapKeyToFloat( + observations=self.observations, + ) + + def test_Init(self) -> None: + self.assertEqual(len(self.t._parameter_list), 1) + + p = self.t._parameter_list[0] + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertTrue(p.log_scale) + + with self.subTest("infer parameter type"): + observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": int(steps), + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + observations.append(Observation(features=obs_feat, data=obs_data)) + + # test that one is able to override default config + with self.subTest(msg="override default config"): + t = MapKeyToFloat( + observations=self.observations, + config={ + "parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}} + }, + ) + self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) + + self.assertEqual(len(t._parameter_list), 1) + + p = t._parameter_list[0] + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertFalse(p.log_scale) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters), + {"height", "width", MapKeyToFloat.DEFAULT_MAP_KEY}, + ) + + p = assert_is_instance( + ss2.parameters[MapKeyToFloat.DEFAULT_MAP_KEY], RangeParameter + ) + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertTrue(p.log_scale) + + def test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + MapKeyToFloat.DEFAULT_MAP_KEY: steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) + + def test_TransformObservationFeaturesWithEmptyParameters(self) -> None: + obsf = ObservationFeatures(parameters={}) + self.t.transform_observation_features([obsf]) + + p = self.t._parameter_list[0] + self.assertEqual( + obsf, + ObservationFeatures(parameters={MapKeyToFloat.DEFAULT_MAP_KEY: p.upper}), + ) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index b7edec63bce..c35831f380a 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -319,6 +319,15 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.map\_key\_to\_float` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.map_key_to_float + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~