Skip to content

Commit

Permalink
Update MTGP transforms to use new MBM_X_trans (#3146)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3146

Follow up to D66724547 to propagate the new transforms to MTGP models.

Reviewed By: lena-kashtelyan

Differential Revision: D66726992

fbshipit-source-id: 1fdeaba5893e79bca909a7a33d14ee4e7a4a046b
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Dec 9, 2024
1 parent 6cf2ae3 commit be685a7
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 20 deletions.
18 changes: 4 additions & 14 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
ChoiceToNumericChoice,
OrderedChoiceToIntegerRange,
)
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
Expand Down Expand Up @@ -131,15 +130,6 @@
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
TS_trans: list[type[Transform]] = Y_trans + [SearchSpaceToChoice]

# Multi-type MTGP transforms
MT_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
Derelativize,
ConvertMetricNames,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]

# Single-type MTGP transforms
ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
Derelativize,
Expand All @@ -148,9 +138,9 @@
TaskChoiceToIntTaskChoice,
]

# Single-type MTGP transforms
Specified_Task_ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [
MBM_MTGP_trans: list[type[Transform]] = MBM_X_trans + [
Derelativize,
TrialAsTask,
StratifiedStandardizeY,
TaskChoiceToIntTaskChoice,
]
Expand Down Expand Up @@ -218,7 +208,7 @@ class ModelSetup(NamedTuple):
"ST_MTGP": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=ST_MTGP_trans,
transforms=MBM_MTGP_trans,
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
),
"BO_MIXED": ModelSetup(
Expand All @@ -241,7 +231,7 @@ class ModelSetup(NamedTuple):
"SAAS_MTGP": ModelSetup(
bridge_class=TorchModelBridge,
model_class=ModularBoTorchModel,
transforms=ST_MTGP_trans,
transforms=MBM_MTGP_trans,
default_model_kwargs={
"surrogate_spec": SurrogateSpec(
botorch_model_class=SaasFullyBayesianMultiTaskGP
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
from ax.modelbridge.registry import (
_extract_model_state_after_gen,
Cont_X_trans,
MBM_MTGP_trans,
MODEL_KEY_TO_MODEL_SETUP,
Models,
ST_MTGP_trans,
)
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transition_criterion import (
Expand Down Expand Up @@ -1106,7 +1106,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
model_kwargs={
# this will cause an error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
"transforms": MBM_MTGP_trans,
**self.step_model_kwargs,
},
num_trials=1,
Expand Down
4 changes: 2 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models, ST_MTGP_trans
from ax.modelbridge.registry import MBM_MTGP_trans, Models
from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin
from ax.runners.synthetic import SyntheticRunner
from ax.service.scheduler import (
Expand Down Expand Up @@ -2391,7 +2391,7 @@ def test_it_works_with_multitask_models(
model_kwargs={
# this will cause and error if the model
# doesn't get fixed features
"transforms": ST_MTGP_trans,
"transforms": MBM_MTGP_trans,
"transform_configs": {
"TrialAsTask": {
"trial_level_map": {
Expand Down
19 changes: 17 additions & 2 deletions tutorials/multi_task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,30 @@
"from ax.core.search_space import SearchSpace\n",
"from ax.metrics.hartmann6 import Hartmann6Metric\n",
"from ax.modelbridge.factory import get_sobol\n",
"from ax.modelbridge.registry import Models, MT_MTGP_trans, ST_MTGP_trans\n",
"from ax.modelbridge.registry import Models, MBM_X_trans, ST_MTGP_trans\n",
"from ax.modelbridge.torch import TorchModelBridge\n",
"from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment\n",
"from ax.modelbridge.transforms.derelativize import Derelativize\n",
"from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames\n",
"from ax.modelbridge.transforms.trial_as_task import TrialAsTask\n",
"from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY\n",
"from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice\n",
"from ax.plot.diagnostic import interact_batch_comparison\n",
"from ax.runners.synthetic import SyntheticRunner\n",
"from ax.utils.common.typeutils import checked_cast\n",
"from ax.utils.notebook.plotting import init_notebook_plotting, render\n",
"\n",
"init_notebook_plotting()"
"init_notebook_plotting()\n",
"\n",
"# Transforms for pre-processing the data from a multi-type experiment to \n",
"# construct a multi-task GP model.\n",
"MT_MTGP_trans = MBM_X_trans + [\n",
" Derelativize,\n",
" ConvertMetricNames,\n",
" TrialAsTask,\n",
" StratifiedStandardizeY,\n",
" TaskChoiceToIntTaskChoice,\n",
"]"
]
},
{
Expand Down

0 comments on commit be685a7

Please sign in to comment.