Skip to content

Commit

Permalink
Correctly instantiate HSS if any parameter is_hierarchical (facebook#…
Browse files Browse the repository at this point in the history
…3199)

Summary:

As titled

Reviewed By: saitcakmak

Differential Revision: D67423037
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Dec 19, 2024
1 parent 7bb3fe4 commit e80239c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
11 changes: 9 additions & 2 deletions ax/preview/api/utils/instantiation/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
RangeParameter,
)
from ax.core.parameter_constraint import validate_constraint_parameters
from ax.core.search_space import SearchSpace
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.preview.api.configs import (
ChoiceParameterConfig,
Expand Down Expand Up @@ -121,7 +121,14 @@ def experiment_from_config(config: ExperimentConfig) -> Experiment:
]
)

search_space = SearchSpace(parameters=parameters, parameter_constraints=constraints)
if any(p.is_hierarchical for p in parameters):
search_space = HierarchicalSearchSpace(
parameters=parameters, parameter_constraints=constraints
)
else:
search_space = SearchSpace(
parameters=parameters, parameter_constraints=constraints
)

return Experiment(
search_space=search_space,
Expand Down
62 changes: 61 additions & 1 deletion ax/preview/api/utils/instantiation/tests/test_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
RangeParameter,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.preview.api.configs import (
ChoiceParameterConfig,
Expand Down Expand Up @@ -257,6 +257,66 @@ def test_experiment_from_config(self) -> None:
),
)

root_parameter = ChoiceParameterConfig(
name="root_param",
parameter_type=ParameterType.STRING,
values=["left", "right"],
dependent_parameters={
"left": ["float_param"],
"right": ["int_param"],
},
)

hss_config = ExperimentConfig(
name="test_experiment",
parameters=[float_parameter, int_parameter, root_parameter],
parameter_constraints=["int_param <= float_param"],
description="test description",
owner="miles",
)

self.assertEqual(
experiment_from_config(config=hss_config),
Experiment(
search_space=HierarchicalSearchSpace(
parameters=[
RangeParameter(
name="float_param",
parameter_type=CoreParameterType.FLOAT,
lower=0,
upper=1,
),
RangeParameter(
name="int_param",
parameter_type=CoreParameterType.INT,
lower=0,
upper=1,
),
ChoiceParameter(
name="root_param",
parameter_type=CoreParameterType.STRING,
values=["left", "right"],
is_ordered=False,
sort_values=False,
dependents={
"left": ["float_param"],
"right": ["int_param"],
},
),
],
parameter_constraints=[
ParameterConstraint(
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
)
],
),
name="test_experiment",
description="test description",
properties={"owners": ["miles"]},
default_data_type=DataType.MAP_DATA,
),
)

def test_parameter_type_converter(self) -> None:
self.assertEqual(
_parameter_type_converter(parameter_type=ParameterType.BOOL),
Expand Down

0 comments on commit e80239c

Please sign in to comment.