diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 5d8f8288a4b..5f657bfb61c 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -36,7 +36,6 @@ ChoiceToNumericChoice, OrderedChoiceToIntegerRange, ) -from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames from ax.modelbridge.transforms.derelativize import Derelativize from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice @@ -131,15 +130,6 @@ # call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`. TS_trans: list[type[Transform]] = Y_trans + [SearchSpaceToChoice] -# Multi-type MTGP transforms -MT_MTGP_trans: list[type[Transform]] = Cont_X_trans + [ - Derelativize, - ConvertMetricNames, - TrialAsTask, - StratifiedStandardizeY, - TaskChoiceToIntTaskChoice, -] - # Single-type MTGP transforms ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [ Derelativize, @@ -148,9 +138,9 @@ TaskChoiceToIntTaskChoice, ] -# Single-type MTGP transforms -Specified_Task_ST_MTGP_trans: list[type[Transform]] = Cont_X_trans + [ +MBM_MTGP_trans: list[type[Transform]] = MBM_X_trans + [ Derelativize, + TrialAsTask, StratifiedStandardizeY, TaskChoiceToIntTaskChoice, ] @@ -218,7 +208,7 @@ class ModelSetup(NamedTuple): "ST_MTGP": ModelSetup( bridge_class=TorchModelBridge, model_class=ModularBoTorchModel, - transforms=ST_MTGP_trans, + transforms=MBM_MTGP_trans, standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, ), "BO_MIXED": ModelSetup( @@ -241,7 +231,7 @@ class ModelSetup(NamedTuple): "SAAS_MTGP": ModelSetup( bridge_class=TorchModelBridge, model_class=ModularBoTorchModel, - transforms=ST_MTGP_trans, + transforms=MBM_MTGP_trans, default_model_kwargs={ "surrogate_spec": SurrogateSpec( botorch_model_class=SaasFullyBayesianMultiTaskGP diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 5fe3282ce41..cbf0ae8f920 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -43,9 +43,9 @@ from ax.modelbridge.registry import ( _extract_model_state_after_gen, Cont_X_trans, + MBM_MTGP_trans, MODEL_KEY_TO_MODEL_SETUP, Models, - ST_MTGP_trans, ) from ax.modelbridge.torch import TorchModelBridge from ax.modelbridge.transition_criterion import ( @@ -1106,7 +1106,7 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features( model_kwargs={ # this will cause an error if the model # doesn't get fixed features - "transforms": ST_MTGP_trans, + "transforms": MBM_MTGP_trans, **self.step_model_kwargs, }, num_trials=1, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index f43e8f14a9d..b9eb7e0be9e 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -44,7 +44,7 @@ from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models, ST_MTGP_trans +from ax.modelbridge.registry import MBM_MTGP_trans, Models from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin from ax.runners.synthetic import SyntheticRunner from ax.service.scheduler import ( @@ -2391,7 +2391,7 @@ def test_it_works_with_multitask_models( model_kwargs={ # this will cause and error if the model # doesn't get fixed features - "transforms": ST_MTGP_trans, + "transforms": MBM_MTGP_trans, "transform_configs": { "TrialAsTask": { "trial_level_map": { diff --git a/tutorials/multi_task.ipynb b/tutorials/multi_task.ipynb index 0efbc74c04c..053b7aa4c4a 100644 --- a/tutorials/multi_task.ipynb +++ b/tutorials/multi_task.ipynb @@ -56,15 +56,30 @@ "from ax.core.search_space import SearchSpace\n", "from ax.metrics.hartmann6 import Hartmann6Metric\n", "from ax.modelbridge.factory import get_sobol\n", - "from ax.modelbridge.registry import Models, MT_MTGP_trans, ST_MTGP_trans\n", + "from ax.modelbridge.registry import Models, MBM_X_trans, ST_MTGP_trans\n", "from ax.modelbridge.torch import TorchModelBridge\n", "from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment\n", + "from ax.modelbridge.transforms.derelativize import Derelativize\n", + "from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames\n", + "from ax.modelbridge.transforms.trial_as_task import TrialAsTask\n", + "from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY\n", + "from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice\n", "from ax.plot.diagnostic import interact_batch_comparison\n", "from ax.runners.synthetic import SyntheticRunner\n", "from ax.utils.common.typeutils import checked_cast\n", "from ax.utils.notebook.plotting import init_notebook_plotting, render\n", "\n", - "init_notebook_plotting()" + "init_notebook_plotting()\n", + "\n", + "# Transforms for pre-processing the data from a multi-type experiment to \n", + "# construct a multi-task GP model.\n", + "MT_MTGP_trans = MBM_X_trans + [\n", + " Derelativize,\n", + " ConvertMetricNames,\n", + " TrialAsTask,\n", + " StratifiedStandardizeY,\n", + " TaskChoiceToIntTaskChoice,\n", + "]" ] }, {