diff --git a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py index d6afa6ad0b8..0ab69a8a0cb 100644 --- a/ax/models/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/models/torch/botorch_modular/input_constructors/input_transforms.py @@ -113,27 +113,30 @@ def _input_transform_argparse_normalize( A dictionary with input transform kwargs. """ input_transform_options = input_transform_options or {} - d = input_transform_options.get("d", len(dataset.feature_names)) - bounds = torch.as_tensor( - search_space_digest.bounds, - dtype=torch_dtype, - device=torch_device, - ).T - if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer): - d = dataset.X.values.shape[-1] + d = input_transform_options.get("d", len(search_space_digest.feature_names)) + input_transform_options["d"] = d indices = list(range(d)) - task_features = normalize_indices(search_space_digest.task_features, d=d) - - for task_feature in sorted(task_features, reverse=True): - del indices[task_feature] - - input_transform_options.setdefault("d", d) + # having indices set to None means that we don't remove task features + if ("indices" in input_transform_options) and ( + input_transform_options["indices"] is None + ): + input_transform_options["indices"] = indices + else: + task_features = normalize_indices(search_space_digest.task_features, d=d) + for task_feature in sorted(task_features, reverse=True): + del indices[task_feature] if ("indices" in input_transform_options) or (len(indices) < d): input_transform_options.setdefault("indices", indices) + bounds = torch.as_tensor( + search_space_digest.bounds, + dtype=torch_dtype, + device=torch_device, + ).T + if ( ("bounds" not in input_transform_options) and (bounds.shape[-1] < d) diff --git a/ax/models/torch/tests/test_input_transform_argparse.py b/ax/models/torch/tests/test_input_transform_argparse.py index 69e4f1e93b6..5e33955c306 100644 --- a/ax/models/torch/tests/test_input_transform_argparse.py +++ b/ax/models/torch/tests/test_input_transform_argparse.py @@ -46,10 +46,10 @@ def setUp(self) -> None: self.search_space_digest = SearchSpaceDigest( feature_names=["a", "b", "c"], bounds=[(0.0, 1.0), (0, 2), (0, 4)], - ordinal_features=[1], - categorical_features=[2], - discrete_choices={1: [0, 1, 2], 2: [0, 0.25, 4.0]}, - task_features=[3], + ordinal_features=[0], + categorical_features=[1], + discrete_choices={0: [0, 1, 2], 1: [0, 0.25, 4.0]}, + task_features=[2], fidelity_features=[0], target_values={0: 1.0}, robust_digest=None, @@ -110,7 +110,8 @@ def test_argparse_normalize(self) -> None: ) ) ) - self.assertEqual(input_transform_kwargs["d"], 4) + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) input_transform_kwargs = input_transform_argparse( Normalize, @@ -125,6 +126,7 @@ def test_argparse_normalize(self) -> None: ) self.assertEqual(input_transform_kwargs["d"], 4) + self.assertEqual(input_transform_kwargs["indices"], [0, 1, 3]) self.assertTrue( torch.all( @@ -157,8 +159,21 @@ def test_argparse_normalize(self) -> None: dataset=mtds, search_space_digest=self.search_space_digest, ) - self.assertEqual(input_transform_kwargs["d"], 4) - self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2]) + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) + + input_transform_kwargs = input_transform_argparse( + Normalize, + dataset=self.dataset, + search_space_digest=self.search_space_digest, + input_transform_options={ + "bounds": None, + }, + ) + + self.assertEqual(input_transform_kwargs["d"], 3) + self.assertEqual(input_transform_kwargs["indices"], [0, 1]) + self.assertTrue(input_transform_kwargs["bounds"] is None) def test_argparse_warp(self) -> None: self.search_space_digest.task_features = [0, 3]