Skip to content

Commit

Permalink
Introduce new Transform that adds metadata as parameters in an Observ…
Browse files Browse the repository at this point in the history
…ationFeature (facebook#3023)

Summary:
Pull Request resolved: facebook#3023

**Context:** The values corresponding to map keys are propagated as part of the `ObservationFeatures`' `metadata` dict field. We require a way to place them in the `parameters` dict field so that they can be used later on. This generalized transform takes user-specified entries from an `ObservationFeatures`'s `metadata` field, places them within its `parameters` field, and updates the search space to reflect this.

This implements a new transform, `MetadataToFloat`, that extracts specified fields from each `ObservationFeatures` instance's metadata and incorporates them as parameters. Furthermore, it updates the search space to include each specified field as a `RangeParameter` with bounds determined by the observations provided during initialization. This process involves analyzing the metadata of each observation feature and identifying the relevant fields that need to be included in the search space. The bounds for these fields are then determined based on the observations provided during initialization.

Differential Revision: D65430943
  • Loading branch information
Louis Tiao authored and facebook-github-bot committed Dec 17, 2024
1 parent c9303da commit fa7843f
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 0 deletions.
143 changes: 143 additions & 0 deletions ax/modelbridge/transforms/metadata_to_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from logging import Logger
from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING

from ax.core import ParameterType

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


# Module-level logger, named after this module via `get_logger(__name__)`.
logger: Logger = get_logger(__name__)


class MetadataToFloat(Transform):
    """Promote numeric metadata entries to float range parameters.

    The transform is configured through ``config["parameters"]``, a dictionary
    that maps each metadata key to a dictionary of keyword arguments for the
    corresponding ``RangeParameter`` (``lower``, ``upper``, ``log_scale``,
    ``logit_scale``, ``digits``, ``is_fidelity`` and ``target_value``). Bounds
    that are not given explicitly default to the smallest and largest value of
    that metadata entry among the observations passed at construction.

    Transforming moves the configured entries from each observation feature's
    ``metadata`` into its ``parameters`` and adds the matching range parameters
    to the search space; untransforming moves the entries back to ``metadata``.

    Transform is done in-place.
    """

    # Defaults used when the per-parameter config omits the respective option.
    DEFAULT_LOG_SCALE: bool = False
    DEFAULT_LOGIT_SCALE: bool = False
    DEFAULT_IS_FIDELITY: bool = False
    ENFORCE_BOUNDS: bool = False

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        """Build one float ``RangeParameter`` per configured metadata key.

        Args:
            search_space: Not used by this constructor.
            observations: Observations whose feature metadata supplies the
                de facto bounds of each configured key. Must be non-empty.
            modelbridge: Not used by this constructor.
            config: Transform options; only the ``"parameters"`` entry is read.

        Raises:
            DataRequiredError: If ``observations`` is ``None`` or empty.
            KeyError: If a configured key is missing from the metadata of one
                of the observations.
        """
        if not observations:
            raise DataRequiredError(
                "`MetadataToFloat` transform requires non-empty data."
            )
        config = config or {}
        self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
            config.get("parameters", {}), dict
        )

        self._parameter_list: list[RangeParameter] = []
        for name, options in self.parameters.items():
            # Every observation must carry metadata with a float-convertible
            # value under `name`; these values define the de facto bounds.
            values = [
                float(
                    assert_is_instance(
                        none_throws(obs.features.metadata)[name], SupportsFloat
                    )
                )
                for obs in observations
            ]

            # Explicitly configured bounds take precedence over observed ones.
            lower: float = options.get("lower", min(values))
            upper: float = options.get("upper", max(values))

            parameter = RangeParameter(
                name=name,
                parameter_type=ParameterType.FLOAT,
                lower=lower,
                upper=upper,
                log_scale=options.get("log_scale", self.DEFAULT_LOG_SCALE),
                logit_scale=options.get("logit_scale", self.DEFAULT_LOGIT_SCALE),
                digits=options.get("digits"),
                is_fidelity=options.get("is_fidelity", self.DEFAULT_IS_FIDELITY),
                target_value=options.get("target_value"),
            )
            self._parameter_list.append(parameter)

    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        """Add a clone of every derived range parameter to ``search_space``.

        The search space is modified in-place and also returned.
        """
        for parameter in self._parameter_list:
            search_space.add_parameter(parameter.clone())
        return search_space

    def transform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Move the configured metadata entries into ``parameters``, in-place.

        Returns the same list that was passed in.
        """
        for obsf in observation_features:
            self._transform_observation_feature(obsf)
        return observation_features

    def untransform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Move the configured entries from ``parameters`` back to ``metadata``.

        A missing ``metadata`` dict is created first. The features are modified
        in-place and the same list is returned.
        """
        for obsf in observation_features:
            obsf.metadata = obsf.metadata or {}
            _transfer(
                src=obsf.parameters,
                dst=obsf.metadata,
                keys=self.parameters.keys(),
            )
        return observation_features

    def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
        """Move the configured metadata entries of one feature into its parameters.

        The feature's ``metadata`` must not be ``None``.
        """
        _transfer(
            src=none_throws(obsf.metadata),
            dst=obsf.parameters,
            keys=self.parameters.keys(),
        )


def _transfer(
    src: dict[str, Any],
    dst: dict[str, Any],
    keys: Iterable[str],
) -> None:
    """Move the entries named by ``keys`` out of ``src`` and into ``dst``.

    Both dictionaries are modified in-place: each key is removed from ``src``
    and stored, with its value, in ``dst`` (overwriting any existing entry).
    A key that is absent from ``src`` raises ``KeyError``; entries moved before
    the failure stay moved.
    """
    for name in keys:
        value = src.pop(name)
        dst[name] = value
180 changes: 180 additions & 0 deletions ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from copy import deepcopy
from typing import Iterator

import numpy as np
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat
from ax.utils.common.testutils import TestCase
from pyre_extensions import assert_is_instance


# Per-trial fixtures: the i-th entry of each list belongs to trial index i.
WIDTHS = [2.0, 4.0, 8.0]
HEIGHTS = [4.0, 2.0, 8.0]
# Number of steps emitted for each trial.
STEPS_ENDS = [1, 5, 3]


def _enumerate() -> Iterator[tuple[int, float, float, float]]:
    """Yield a ``(trial_index, width, height, step)`` tuple for every step.

    Trials are visited in order; within a trial the step runs from ``1.0`` up
    to that trial's entry in ``STEPS_ENDS`` (inclusive), as a float.
    """
    trials = zip(WIDTHS, HEIGHTS, STEPS_ENDS)
    for trial_index, (width, height, steps_end) in enumerate(trials):
        for step in range(1, steps_end + 1):
            yield trial_index, width, height, float(step)


class MetadataToFloatTransformTest(TestCase):
    """Unit tests for the ``MetadataToFloat`` transform."""

    def _make_observations(self, integer_steps: bool = False) -> list[Observation]:
        """Build one observation per ``(trial, step)`` pair from ``_enumerate``.

        Each observation has ``width``/``height`` parameters and a metadata dict
        with a constant ``foo`` entry and a step-dependent ``bar`` entry:
        ``3.0 * steps`` by default, or ``int(steps)`` when ``integer_steps``.
        """
        observations = []
        for trial_index, width, height, steps in _enumerate():
            obs_feat = ObservationFeatures(
                trial_index=trial_index,
                parameters={"width": width, "height": height},
                metadata={
                    "foo": 42,
                    "bar": int(steps) if integer_steps else 3.0 * steps,
                },
            )
            obs_data = ObservationData(
                metric_names=[], means=np.array([]), covariance=np.empty((0, 0))
            )
            observations.append(Observation(features=obs_feat, data=obs_data))
        return observations

    def setUp(self) -> None:
        super().setUp()

        self.search_space = SearchSpace(
            parameters=[
                RangeParameter(
                    name="width",
                    parameter_type=ParameterType.FLOAT,
                    lower=1,
                    upper=20,
                ),
                RangeParameter(
                    name="height",
                    parameter_type=ParameterType.FLOAT,
                    lower=1,
                    upper=20,
                ),
            ]
        )

        self.observations = self._make_observations()

        self.t = MetadataToFloat(
            observations=self.observations,
            config={
                "parameters": {"bar": {"log_scale": True}},
            },
        )

    def test_Init(self) -> None:
        self.assertEqual(len(self.t._parameter_list), 1)

        p = self.t._parameter_list[0]

        # check that the parameter options are specified in a sensible manner
        # by default if the user does not specify them explicitly
        self.assertEqual(p.name, "bar")
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 3.0)
        self.assertEqual(p.upper, 15.0)
        self.assertTrue(p.log_scale)
        self.assertFalse(p.logit_scale)
        self.assertIsNone(p.digits)
        self.assertFalse(p.is_fidelity)
        self.assertIsNone(p.target_value)

        with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
            MetadataToFloat(search_space=None, observations=None)
        with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
            MetadataToFloat(search_space=None, observations=[])

        with self.subTest("integer metadata is converted to float"):
            t = MetadataToFloat(
                observations=self._make_observations(integer_steps=True),
                config={
                    "parameters": {"bar": {}},
                },
            )
            self.assertEqual(len(t._parameter_list), 1)

            p = t._parameter_list[0]

            # The transform always builds FLOAT range parameters, so integer
            # metadata values end up as float bounds.
            self.assertEqual(p.name, "bar")
            self.assertEqual(p.parameter_type, ParameterType.FLOAT)
            self.assertEqual(p.lower, 1.0)
            self.assertEqual(p.upper, 5.0)
            self.assertFalse(p.log_scale)
            self.assertFalse(p.logit_scale)
            self.assertIsNone(p.digits)
            self.assertFalse(p.is_fidelity)
            self.assertIsNone(p.target_value)

    def test_TransformSearchSpace(self) -> None:
        ss2 = deepcopy(self.search_space)
        ss2 = self.t.transform_search_space(ss2)

        self.assertSetEqual(
            set(ss2.parameters.keys()),
            {"height", "width", "bar"},
        )

        p = assert_is_instance(ss2.parameters["bar"], RangeParameter)

        self.assertEqual(p.name, "bar")
        self.assertEqual(p.parameter_type, ParameterType.FLOAT)
        self.assertEqual(p.lower, 3.0)
        self.assertEqual(p.upper, 15.0)
        self.assertTrue(p.log_scale)
        self.assertFalse(p.logit_scale)
        self.assertIsNone(p.digits)
        self.assertFalse(p.is_fidelity)
        self.assertIsNone(p.target_value)

    def test_TransformObservationFeatures(self) -> None:
        observation_features = [obs.features for obs in self.observations]
        obs_ft2 = deepcopy(observation_features)
        obs_ft2 = self.t.transform_observation_features(obs_ft2)

        # "bar" moves from metadata into parameters; "foo" stays in metadata.
        self.assertEqual(
            obs_ft2,
            [
                ObservationFeatures(
                    trial_index=trial_index,
                    parameters={
                        "width": width,
                        "height": height,
                        "bar": 3.0 * steps,
                    },
                    metadata={"foo": 42},
                )
                for trial_index, width, height, steps in _enumerate()
            ],
        )
        # Untransforming restores the original features.
        obs_ft2 = self.t.untransform_observation_features(obs_ft2)
        self.assertEqual(obs_ft2, observation_features)
9 changes: 9 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ Transforms
:undoc-members:
:show-inheritance:


`ax.modelbridge.transforms.metadata\_to\_float`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.modelbridge.transforms.metadata_to_float
    :members:
    :undoc-members:
    :show-inheritance:

`ax.modelbridge.transforms.rounding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit fa7843f

Please sign in to comment.