Skip to content

Commit

Permalink
change normalize input constructor to accomodate tl modelbridge (face…
Browse files Browse the repository at this point in the history
…book#3185)

Summary:

Removing dependency of normalize transform for TL config on the properties of the search space.

Reviewed By: saitcakmak

Differential Revision: D67228575
  • Loading branch information
Jelena Markovic-Voronov authored and facebook-github-bot committed Dec 19, 2024
1 parent 8e3b456 commit 7918f7b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,39 @@ def _input_transform_argparse_normalize(
A dictionary with input transform kwargs.
"""
input_transform_options = input_transform_options or {}
d = input_transform_options.get("d", len(dataset.feature_names))
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer):
d = dataset.X.values.shape[-1]
d = input_transform_options.get("d", len(dataset.feature_names))
# having d set to None means that we use the search space digest to infer d
if "d" in input_transform_options and input_transform_options["d"] is None:
d = len(search_space_digest.feature_names)
input_transform_options["d"] = d
else:
if isinstance(dataset, RankingDataset) and isinstance(
dataset.X, SliceContainer
):
d = dataset.X.values.shape[-1]
input_transform_options.setdefault("d", d)

indices = list(range(d))
task_features = normalize_indices(search_space_digest.task_features, d=d)

for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

input_transform_options.setdefault("d", d)
# having indices set to None means that we don't remove task features
if ("indices" in input_transform_options) and (
input_transform_options["indices"] is None
):
input_transform_options["indices"] = indices
else:
task_features = normalize_indices(search_space_digest.task_features, d=d)
for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

if ("indices" in input_transform_options) or (len(indices) < d):
input_transform_options.setdefault("indices", indices)

bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if (
("bounds" not in input_transform_options)
and (bounds.shape[-1] < d)
Expand Down
30 changes: 23 additions & 7 deletions ax/models/torch/tests/test_input_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def setUp(self) -> None:
super().setUp()
X = torch.randn((10, 4))
Y = torch.randn((10, 2))

self.dataset = SupervisedDataset(
X=X,
Y=Y,
Expand All @@ -46,10 +47,10 @@ def setUp(self) -> None:
self.search_space_digest = SearchSpaceDigest(
feature_names=["a", "b", "c"],
bounds=[(0.0, 1.0), (0, 2), (0, 4)],
ordinal_features=[1],
categorical_features=[2],
discrete_choices={1: [0, 1, 2], 2: [0, 0.25, 4.0]},
task_features=[3],
ordinal_features=[0],
categorical_features=[1],
discrete_choices={0: [0, 1, 2], 1: [0, 0.25, 4.0]},
task_features=[2],
fidelity_features=[0],
target_values={0: 1.0},
robust_digest=None,
Expand Down Expand Up @@ -110,7 +111,8 @@ def test_argparse_normalize(self) -> None:
)
)
)
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertEqual(input_transform_kwargs["d"], 3)
self.assertEqual(input_transform_kwargs["indices"], [0, 1])

input_transform_kwargs = input_transform_argparse(
Normalize,
Expand All @@ -125,6 +127,7 @@ def test_argparse_normalize(self) -> None:
)

self.assertEqual(input_transform_kwargs["d"], 4)
self.assertEqual(input_transform_kwargs["indices"], [0, 1, 3])

self.assertTrue(
torch.all(
Expand Down Expand Up @@ -157,8 +160,21 @@ def test_argparse_normalize(self) -> None:
dataset=mtds,
search_space_digest=self.search_space_digest,
)
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2])
self.assertEqual(input_transform_kwargs["d"], 3)
self.assertEqual(input_transform_kwargs["indices"], [0, 1])

input_transform_kwargs = input_transform_argparse(
Normalize,
dataset=self.dataset,
search_space_digest=self.search_space_digest,
input_transform_options={
"bounds": None,
},
)

self.assertEqual(input_transform_kwargs["d"], 3)
self.assertEqual(input_transform_kwargs["indices"], [0, 1])
self.assertTrue(input_transform_kwargs["bounds"] is None)

def test_argparse_warp(self) -> None:
self.search_space_digest.task_features = [0, 3]
Expand Down

0 comments on commit 7918f7b

Please sign in to comment.