Skip to content

Commit

Permalink
Add LogIntToFloat transform (#3091)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3091

This is a simple subclass of `IntToFloat` that only transforms log-scale parameters.

Replacing `IntToFloat` with `LogIntToFloat` will avoid unnecessary use of continuous relaxation across the board, and allow us to utilize the various optimizers available in `Acquisition.optimize`.

Additional context:
With log-scale parameters, we have two options: transform them in Ax or transform them in BoTorch. Transforming them in Ax leads to both modeling and optimizing the parameter in the log-scale (good), but transforming in BoTorch leads to modeling in log-scale but optimizing in the raw scale (not ideal) and also introduces `TransformedPosterior` and some incompatibilities it brings. So, we want to transform log-scale parameters in Ax.
Since the log of an integer parameter is no longer an integer, such parameters must be relaxed to floats. But we don't want to relax any other int parameters, so we can't simply use `IntToFloat`. `LogIntToFloat` makes it possible to apply continuous relaxation only to the log-scale parameters, which is a step in the right direction.

Reviewed By: dme65

Differential Revision: D66244582

fbshipit-source-id: a1f24e8ad1f8af66ec34db959a6616c5e0ed8dd8
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 3, 2024
1 parent 19fdaff commit 258400c
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SearchSpace(Base):

def __init__(
self,
parameters: list[Parameter],
parameters: Sequence[Parameter],
parameter_constraints: list[ParameterConstraint] | None = None,
) -> None:
"""Initialize SearchSpace
Expand Down
55 changes: 49 additions & 6 deletions ax/modelbridge/transforms/int_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.rounding import (
contains_constrained_integer,
Expand Down Expand Up @@ -65,18 +66,22 @@ def __init__(
self.min_choices: int = checked_cast(int, config.get("min_choices", 0))

# Identify parameters that should be transformed
self.transform_parameters: set[str] = {
self.transform_parameters: set[str] = self._get_transform_parameters()
if contains_constrained := contains_constrained_integer(
self.search_space, self.transform_parameters
):
self.rounding = "randomized"
self.contains_constrained_integer: bool = contains_constrained

def _get_transform_parameters(self) -> set[str]:
    """Return the names of the parameters this transform should relax.

    A parameter qualifies when it is an integer ``RangeParameter`` that
    either has at least ``min_choices`` values or is log-scale.
    """
    eligible: set[str] = set()
    for name, param in self.search_space.parameters.items():
        if not isinstance(param, RangeParameter):
            continue
        if param.parameter_type != ParameterType.INT:
            continue
        if param.cardinality() >= self.min_choices or param.log_scale:
            eligible.add(name)
    return eligible
if contains_constrained_integer(self.search_space, self.transform_parameters):
self.rounding = "randomized"
self.contains_constrained_integer: bool = True
else:
self.contains_constrained_integer: bool = False

def transform_observation_features(
self, observation_features: list[ObservationFeatures]
Expand Down Expand Up @@ -183,3 +188,41 @@ def untransform_observation_features(
obsf.parameters[p_name] = rounded_parameters[p_name]

return observation_features


class LogIntToFloat(IntToFloat):
    """Convert a log-scale RangeParameter of type int to type float.

    The behavior of this transform mirrors ``IntToFloat`` with the key
    difference being that it only operates on log-scale parameters.
    """

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        observations: list[Observation] | None = None,
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: TConfig | None = None,
    ) -> None:
        """Initialize the transform.

        Args:
            search_space: The search space whose log-scale int parameters
                should be relaxed to float.
            observations: Observations to transform (handled by the parent).
            modelbridge: The model bridge using this transform, if any.
            config: Transform config. Must not contain ``min_choices``,
                which only applies to the parent ``IntToFloat`` transform.

        Raises:
            UserInputError: If ``config`` specifies ``min_choices``.
        """
        # `min_choices` controls which int parameters the parent transform
        # skips; this subclass selects parameters solely by `log_scale`,
        # so the option is meaningless here and is rejected outright.
        if config is not None and "min_choices" in config:
            raise UserInputError(
                "`min_choices` cannot be specified for `LogIntToFloat` transform."
            )
        super().__init__(
            search_space=search_space,
            observations=observations,
            modelbridge=modelbridge,
            config=config,
        )
        # Delete the attribute to avoid it presenting a misleading value.
        del self.min_choices

    def _get_transform_parameters(self) -> set[str]:
        """Identify parameters that should be transformed.

        Only log-scale integer ``RangeParameter``s qualify.
        """
        return {
            p_name
            for p_name, p in self.search_space.parameters.items()
            if isinstance(p, RangeParameter)
            and p.parameter_type == ParameterType.INT
            and p.log_scale
        }
34 changes: 32 additions & 2 deletions ax/modelbridge/transforms/tests/test_int_to_float_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import OrderConstraint, SumConstraint
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.transforms.int_to_float import IntToFloat
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_robust_search_space
Expand Down Expand Up @@ -324,3 +324,33 @@ def test_w_parameter_distributions(self) -> None:
)
with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
t.transform_search_space(rss)


class LogIntToFloatTransformTest(TestCase):
    def test_log_int_to_float(self) -> None:
        # A small int parameter, a large int parameter, and a log-scale int.
        x = RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.INT)
        y = RangeParameter("y", lower=1, upper=50, parameter_type=ParameterType.INT)
        z = RangeParameter(
            "z", lower=1, upper=50, parameter_type=ParameterType.INT, log_scale=True
        )
        search_space = SearchSpace(parameters=[x, y, z])
        # `min_choices` belongs to IntToFloat and must be rejected here.
        with self.assertRaisesRegex(UserInputError, "min_choices"):
            LogIntToFloat(search_space=search_space, config={"min_choices": 5})
        transform = LogIntToFloat(search_space=search_space)
        self.assertFalse(hasattr(transform, "min_choices"))
        # Only the log-scale parameter is selected for relaxation.
        self.assertEqual(transform.transform_parameters, {"z"})
        transformed = transform.transform_search_space(search_space)
        self.assertEqual(transformed.parameters["x"], x)
        self.assertEqual(transformed.parameters["y"], y)
        expected_z = RangeParameter(
            name="z",
            lower=0.50001,
            upper=50.49999,
            parameter_type=ParameterType.FLOAT,
            log_scale=True,
        )
        self.assertEqual(transformed.parameters["z"], expected_z)
3 changes: 2 additions & 1 deletion ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
from ax.modelbridge.transforms.int_to_float import IntToFloat
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
from ax.modelbridge.transforms.ivw import IVW
from ax.modelbridge.transforms.log import Log
from ax.modelbridge.transforms.log_y import LogY
Expand Down Expand Up @@ -95,6 +95,7 @@
TimeAsFeature: 27,
TransformToNewSQ: 28,
FillMissingParameters: 29,
LogIntToFloat: 30,
}

"""
Expand Down

0 comments on commit 258400c

Please sign in to comment.