From 461b04ef3d8a4a5a29dbecf6ab68b5379135bbe5 Mon Sep 17 00:00:00 2001 From: Jelena Markovic-Voronov Date: Mon, 16 Dec 2024 15:54:12 -0800 Subject: [PATCH] change normalize input constructor to accomodate tl modelbridge Summary: Removing dependency of normalize transform for TL config on the properties of the search space. Reviewed By: saitcakmak Differential Revision: D67228575 --- .../input_constructors/input_transforms.py | 23 +++++++------- .../tests/test_input_transform_argparse.py | 30 ++++++++++++++----- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index d6afa6ad0b8..2aee5035581 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -113,26 +113,23 @@ def _input_transform_argparse_normalize( A dictionary with input transform kwargs. """ input_transform_options = input_transform_options or {} - d = input_transform_options.get("d", len(dataset.feature_names)) - bounds = torch.as_tensor( - search_space_digest.bounds, - dtype=torch_dtype, - device=torch_device, - ).T - if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer): - d = dataset.X.values.shape[-1] + d = input_transform_options.get("d", len(search_space_digest.feature_names)) + input_transform_options["d"] = d + # having indices set to None avoids removing the task features indices = list(range(d)) - task_features = normalize_indices(search_space_digest.task_features, d=d) + task_features = normalize_indices(search_space_digest.task_features, d=d) for task_feature in sorted(task_features, reverse=True): del indices[task_feature] + input_transform_options.setdefault("indices", indices) - input_transform_options.setdefault("d", d) - - if ("indices" in input_transform_options) or (len(indices) < d): - input_transform_options.setdefault("indices", indices) + bounds = torch.as_tensor( + search_space_digest.bounds, + dtype=torch_dtype, + device=torch_device, + ).T if ( ("bounds" not in input_transform_options) diff --git a/ax/models/torch/tests/test_input_transform_argparse.py b/ax/models/torch/tests/test_input_transform_argparse.py index 69e4f1e93b6..7b2f74282a7 100644 --- a/ax/models/torch/tests/test_input_transform_argparse.py +++ b/ax/models/torch/tests/test_input_transform_argparse.py @@ -37,6 +37,7 @@ def setUp(self) -> None: super().setUp() X = torch.randn((10, 4)) Y = torch.randn((10, 2)) + self.dataset = SupervisedDataset( X=X, Y=Y, @@ -46,10 +47,10 @@ def setUp(self) -> None: self.search_space_digest = SearchSpaceDigest( feature_names=["a", "b", "c"], bounds=[(0.0, 1.0), (0, 2), (0, 4)], - ordinal_features=[1], - categorical_features=[2], - discrete_choices={1: [0, 1, 2], 2: [0, 0.25, 4.0]}, - task_features=[3], + ordinal_features=[0], + categorical_features=[1], + discrete_choices={0: [0, 1, 2], 1: [0, 0.25, 4.0]}, + task_features=[2], fidelity_features=[0], target_values={0: 1.0}, robust_digest=None, @@ -110,7 +111,8 @@ def test_argparse_normalize(self) -> None: ) ) ) - self.assertEqual(input_transform_kwargs["d"], 4) + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) input_transform_kwargs = input_transform_argparse( Normalize, @@ -125,6 +127,7 @@ def test_argparse_normalize(self) -> None: ) self.assertEqual(input_transform_kwargs["d"], 4) + self.assertEqual(input_transform_kwargs["indices"], [0, 1, 3]) self.assertTrue( torch.all( @@ -157,8 +160,21 @@ def test_argparse_normalize(self) -> None: dataset=mtds, search_space_digest=self.search_space_digest, ) - self.assertEqual(input_transform_kwargs["d"], 4) - self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2]) + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) + + input_transform_kwargs = input_transform_argparse( + Normalize, + dataset=self.dataset, + search_space_digest=self.search_space_digest, + input_transform_options={ + "bounds": None, + }, + ) + + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) + self.assertTrue(input_transform_kwargs["bounds"] is None) def test_argparse_warp(self) -> None: self.search_space_digest.task_features = [0, 3]