Skip to content

Commit

Permalink
Update MTGP transforms to use new MBM_X_trans
Browse files Browse the repository at this point in the history
Summary: Follow up to D66724547 to propagate the new transforms to MTGP models.

Differential Revision: D66726992
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 4, 2024
1 parent 4b2c0b5 commit 7f880e7
Show file tree
Hide file tree
Showing 4 changed files with 604 additions and 600 deletions.
18 changes: 4 additions & 14 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
ChoiceToNumericChoice,
OrderedChoiceToIntegerRange,
)
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
Expand Down Expand Up @@ -131,15 +130,6 @@
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
TS_trans: list[type[Transform]] = Y_trans + [SearchSpaceToChoice]

# Multi-type MTGP transforms
MT_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
Derelativize,
ConvertMetricNames,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]

# Single-type MTGP transforms
ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
Derelativize,
Expand All @@ -148,9 +138,9 @@
TaskChoiceToIntTaskChoice,
]

# Single-type MTGP transforms
Specified_Task_ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
MBM_MTGP_trans: list[type[Transform]] = MBM_X_trans + [
Derelativize,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]
Expand Down Expand Up @@ -218,7 +208,7 @@ class ModelSetup(NamedTuple):
"ST_MTGP": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=ST_MTGP_trans,
transforms=MBM_MTGP_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"BO_MIXED": ModelSetup(
Expand All @@ -241,7 +231,7 @@ class ModelSetup(NamedTuple):
"SAAS_MTGP": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=ST_MTGP_trans,
transforms=MBM_MTGP_trans,
default_model_kwargs={
"surrogate_spec": SurrogateSpec(
botorch_model_class=SaasFullyBayesianMultiTaskGP
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
from ax.modelbridge.registry import (
_extract_model_state_after_gen,
Cont_X_trans,
MBM_MTGP_trans,
MODEL_KEY_TO_MODEL_SETUP,
Models,
ST_MTGP_trans,
)
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
Expand Down Expand Up @@ -1106,7 +1106,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
model_kwargs={
# this will cause an error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
"transforms": MBM_MTGP_trans,
**self.step_model_kwargs,
},
num_trials=1,
Expand Down
4 changes: 2 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models, ST_MTGP_trans
from ax.modelbridge.registry import MBM_MTGP_trans, Models
from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import (
Expand Down Expand Up @@ -2391,7 +2391,7 @@ def test_it_works_with_multitask_models(
model_kwargs={
# this will cause and error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
"transforms": MBM_MTGP_trans,
"transform_configs": {
"TrialAsTask": {
"trial_level_map": {
Expand Down
Loading

0 comments on commit 7f880e7

Please sign in to comment.