Skip to content

Commit

Permalink
Update the default set of Ax transforms used in MBM (#3144)
Browse files Browse the repository at this point in the history
Summary:

This diff adds a new set of transforms (to be used by default in MBM based models) that replaces `IntToFloat` with `LogIntToFloat` and `UnitX` with `Normalize`. The new set of transforms avoids using continuous relaxation for non-log-scale discrete parameters, which consistently delivers improved optimization performance on mixed integer benchmark problems.

This diff only updates single task model registry entries. I will follow it up with additional diffs to propagate the changes in multiple stages.

Differential Revision: D66724547
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 4, 2024
1 parent 258400c commit 4b2c0b5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
32 changes: 29 additions & 3 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
from ax.modelbridge.transforms.int_to_float import IntToFloat
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
from ax.modelbridge.transforms.ivw import IVW
from ax.modelbridge.transforms.log import Log
from ax.modelbridge.transforms.logit import Logit
Expand Down Expand Up @@ -76,6 +76,10 @@

logger: Logger = get_logger(__name__)

# This set of transforms uses continuous relaxation to handle discrete parameters.
# All candidate generation is done in the continuous space, and the generated
# candidates are rounded to fit the original search space. This can be
# suboptimal when there are discrete parameters with a small number of options.
Cont_X_trans: list[type[Transform]] = [
FillMissingParameters,
RemoveFixed,
Expand All @@ -87,8 +91,30 @@
UnitX,
]

# This is a modification of Cont_X_trans that aims to avoid continuous relaxation
# where possible. It replaces IntToFloat with LogIntToFloat, which only transforms
# log-scale integer parameters (these still use continuous relaxation). Other discrete
# transforms will remain discrete. When used with MBM, a Normalize input transform
# will be added to replace the UnitX transform. This setup facilitates the use of
# optimize_acqf_mixed_alternating, which is a more efficient acquisition function
# optimizer for mixed discrete/continuous problems.
MBM_X_trans: list[type[Transform]] = [
FillMissingParameters,
RemoveFixed,
OrderedChoiceToIntegerRange,
OneHot,
LogIntToFloat,
Log,
Logit,
]


Discrete_X_trans: list[type[Transform]] = [IntRangeToChoice]

# This is a modification of Cont_X_trans that replaces OneHot and
# OrderedChoiceToIntegerRange with ChoiceToNumericChoice. This results in retaining
# all choice parameters as discrete, while using continuous relaxation for integer
# valued RangeParameters.
Mixed_transforms: list[type[Transform]] = [
FillMissingParameters,
RemoveFixed,
Expand Down Expand Up @@ -155,7 +181,7 @@ class ModelSetup(NamedTuple):
"BoTorch": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=Cont_X_trans + Y_trans,
transforms=MBM_X_trans + Y_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"Legacy_GPEI": ModelSetup(
Expand Down Expand Up @@ -204,7 +230,7 @@ class ModelSetup(NamedTuple):
"SAASBO": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=Cont_X_trans + Y_trans,
transforms=MBM_X_trans + Y_trans,
default_model_kwargs={
"surrogate_spec": SurrogateSpec(
botorch_model_class=SaasFullyBayesianSingleTaskGP
Expand Down
5 changes: 2 additions & 3 deletions ax/modelbridge/tests/test_dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
choose_generation_strategy,
DEFAULT_BAYESIAN_PARALLELISM,
)
from ax.modelbridge.factory import Cont_X_trans, Y_trans
from ax.modelbridge.registry import Mixed_transforms, Models
from ax.modelbridge.registry import MBM_X_trans, Mixed_transforms, Models, Y_trans
from ax.modelbridge.transforms.log_y import LogY
from ax.modelbridge.transforms.winsorize import Winsorize
from ax.models.winsorization_config import WinsorizationConfig
Expand All @@ -44,7 +43,7 @@ class TestDispatchUtils(TestCase):

@mock_botorch_optimize
def test_choose_generation_strategy(self) -> None:
expected_transforms = [Winsorize] + Cont_X_trans + Y_trans
expected_transforms = [Winsorize] + MBM_X_trans + Y_trans
expected_transform_configs = {
"Winsorize": {"derelativize_with_raw_status_quo": False},
"Derelativize": {"use_raw_status_quo": False},
Expand Down
12 changes: 9 additions & 3 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Models
from ax.modelbridge.registry import Cont_X_trans, Models
from ax.runners.synthetic import SyntheticRunner

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.best_point import (
get_best_parameters_from_model_predictions_with_trial_index,
Expand Down Expand Up @@ -220,7 +219,14 @@ def get_client_with_simple_discrete_moo_problem(
gs = GenerationStrategy(
steps=[
GenerationStep(model=Models.SOBOL, num_trials=3),
GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
# To avoid search space exhausted errors.
"transforms": Cont_X_trans,
},
),
]
)

Expand Down

0 comments on commit 4b2c0b5

Please sign in to comment.