Skip to content

Commit

Permalink
Type annotations -- mark some parameters as immutable; copy immutable…
Browse files Browse the repository at this point in the history
… arguments that may be mutated (#3213)

Summary:
Pull Request resolved: #3213

Context:

I ran into some of these while adding a benchmark and decided to fix the annotations rather than adding pyre-ignores

This PR:
* Marks `BenchmarkProblem.status_quo_params` as immutable
* Marks `SurrogateTestFunction.outcome_names` as immutable
* Copies parameters when they are passed to `ObservationFeatures` so they will not be mutated by transforms
* Marks `parameters` passed to `Arm` as immutable
* Marks parameters as immutalbe in various search space functions
* Removes unused ignores

Reviewed By: Balandat

Differential Revision: D67624124

fbshipit-source-id: 14c7c6d39bb465fc43e0eaafd014d010e8fac4e8
  • Loading branch information
esantorella authored and facebook-github-bot committed Dec 24, 2024
1 parent c561e7a commit 6607582
Show file tree
Hide file tree
Showing 26 changed files with 23 additions and 85 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class BenchmarkProblem(Base):
n_best_points: int = 1
step_runtime_function: TBenchmarkStepRuntimeFunction | None = None
target_fidelity_and_task: Mapping[str, TParamValue] = field(default_factory=dict)
status_quo_params: TParameterization | None = None
status_quo_params: Mapping[str, TParamValue] | None = None
auxiliary_experiments_by_purpose: (
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None
) = None
Expand Down
8 changes: 4 additions & 4 deletions ax/benchmark/benchmark_test_functions/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pyre-strict

from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass

import torch
Expand Down Expand Up @@ -38,7 +38,7 @@ class SurrogateTestFunction(BenchmarkTestFunction):
"""

name: str
outcome_names: list[str]
outcome_names: Sequence[str]
_surrogate: TorchModelBridge | None = None
get_surrogate: None | Callable[[], TorchModelBridge] = None

Expand All @@ -59,8 +59,8 @@ def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
# We're ignoring the uncertainty predictions of the surrogate model here and
# use the mean predictions as the outcomes (before potentially adding noise)
means, _ = self.surrogate.predict(
# pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
observation_features=[ObservationFeatures(params)]
# `dict` makes a copy so that parameters are not mutated
observation_features=[ObservationFeatures(parameters=dict(params))]
)
means = [means[name][0] for name in self.outcome_names]
return torch.tensor(
Expand Down
5 changes: 0 additions & 5 deletions ax/benchmark/tests/test_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,6 @@ def test_with_learning_curve(self) -> None:
)

trial = Trial(experiment=experiment)
# pyre-fixme: Incompatible parameter type [6]: In call
# `Arm.__init__`, for argument `parameters`, expected `Dict[str,
# Union[None, bool, float, int, str]]` but got `Dict[str,
# float]`.
arm = Arm(name="0_0", parameters=params)
trial.add_arm(arm=arm)
metadata_dict = runner.run(trial=trial)
Expand All @@ -413,7 +409,6 @@ def test_with_learning_curve(self) -> None:
test_function=test_function, noise_std=0.0, max_concurrency=2
)

# pyre-fixme[6]: Incompatible parameter type (because argument is mutable)
arm = Arm(name="0_0", parameters=params)
trial = Trial(experiment=experiment)
trial.add_arm(arm=arm)
Expand Down
6 changes: 4 additions & 2 deletions ax/core/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ class Arm(SortableBase):
encapsulates the parametrization needed by the unit.
"""

def __init__(self, parameters: TParameterization, name: str | None = None) -> None:
def __init__(
self, parameters: Mapping[str, TParamValue], name: str | None = None
) -> None:
"""Inits Arm.
Args:
Expand Down Expand Up @@ -132,7 +134,7 @@ def _unique_id(self) -> str:


def _numpy_types_to_python_types(
parameterization: TParameterization,
parameterization: Mapping[str, TParamValue],
) -> TParameterization:
"""If applicable, coerce values of the parameterization from Numpy int/float to
Python int/float.
Expand Down
10 changes: 4 additions & 6 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def update_parameter(self, parameter: Parameter) -> None:
self._parameters[parameter.name] = parameter

def check_all_parameters_present(
self,
parameterization: TParameterization,
raise_error: bool = False,
self, parameterization: Mapping[str, TParamValue], raise_error: bool = False
) -> bool:
"""Whether a given parameterization contains all the parameters in the
search space.
Expand All @@ -204,7 +202,7 @@ def check_all_parameters_present(

def check_membership(
self,
parameterization: TParameterization,
parameterization: Mapping[str, TParamValue],
raise_error: bool = False,
check_all_parameters_present: bool = True,
) -> bool:
Expand Down Expand Up @@ -567,7 +565,7 @@ def flatten_observation_features(

def check_membership(
self,
parameterization: TParameterization,
parameterization: Mapping[str, TParamValue],
raise_error: bool = False,
check_all_parameters_present: bool = True,
) -> bool:
Expand Down Expand Up @@ -673,7 +671,7 @@ def _cast_arm(self, arm: Arm) -> Arm:

def _cast_parameterization(
self,
parameters: TParameterization,
parameters: Mapping[str, TParamValue],
check_all_parameters_present: bool = True,
) -> TParameterization:
"""Cast parameterization (of an arm, observation features, etc.) to the
Expand Down
6 changes: 0 additions & 6 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,6 @@ def test_NormalizedArmWeights(self) -> None:
{"w": 0.75, "x": 1, "y": "foo", "z": True},
{"w": 0.77, "x": 2, "y": "foo", "z": True},
]
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
arms = [Arm(parameters=p) for i, p in enumerate(parameterizations)]
new_batch_trial.add_arms_and_weights(arms=arms, weights=[2, 1])

Expand Down Expand Up @@ -592,8 +590,6 @@ def test_SetStatusQuoAndOptimizePower(self) -> None:
{"w": 0.75, "x": 1, "y": "foo", "z": True},
{"w": 0.77, "x": 2, "y": "foo", "z": True},
]
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
arms = [Arm(parameters=p) for i, p in enumerate(parameterizations)]
batch_trial.add_arms_and_weights(arms=arms)
batch_trial.set_status_quo_and_optimize_power(status_quo)
Expand All @@ -619,8 +615,6 @@ def test_SetStatusQuoAndOptimizePower(self) -> None:
{"w": 0.77, "x": 2, "y": "foo", "z": True},
{"w": 0.0, "x": 1, "y": "foo", "z": True},
]
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
arms = [Arm(parameters=p) for i, p in enumerate(parameterizations)]
batch_trial.add_arms_and_weights(arms=arms)
batch_trial.set_status_quo_and_optimize_power(status_quo)
Expand Down
4 changes: 2 additions & 2 deletions ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def test_MaxValuesValidation(self) -> None:
ChoiceParameter(
name="x",
parameter_type=ParameterType.INT,
values=list(range(999)), # pyre-ignore
values=list(range(999)),
)
with self.assertRaisesRegex(
UserInputError,
Expand All @@ -396,7 +396,7 @@ def test_MaxValuesValidation(self) -> None:
ChoiceParameter(
name="x",
parameter_type=ParameterType.INT,
values=list(range(1001)), # pyre-ignore
values=list(range(1001)),
)

def test_Hierarchical(self) -> None:
Expand Down
22 changes: 0 additions & 22 deletions ax/core/tests/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,48 +254,30 @@ def test_CheckMembership(self) -> None:
p_dict = {"a": 1.0, "b": 5, "c": "foo", "d": True, "e": 0.2, "f": 5}

# Valid
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
self.assertTrue(self.ss2.check_membership(p_dict))

# Value out of range
p_dict["a"] = 20.0
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
self.assertFalse(self.ss2.check_membership(p_dict))
with self.assertRaises(ValueError):
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, Union[float, str]]`.
self.ss2.check_membership(p_dict, raise_error=True)

# Violate constraints
p_dict["a"] = 5.3
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
self.assertFalse(self.ss2.check_membership(p_dict))
with self.assertRaises(ValueError):
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, Union[float, str]]`.
self.ss2.check_membership(p_dict, raise_error=True)

# Incomplete param dict
p_dict.pop("a")
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
self.assertFalse(self.ss2.check_membership(p_dict))
with self.assertRaises(ValueError):
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, Union[float, str]]`.
self.ss2.check_membership(p_dict, raise_error=True)

# Unknown parameter
p_dict["q"] = 40
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
self.assertFalse(self.ss2.check_membership(p_dict))
with self.assertRaises(ValueError):
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, Union[float, str]]`.
self.ss2.check_membership(p_dict, raise_error=True)

def test_CheckTypes(self) -> None:
Expand Down Expand Up @@ -335,15 +317,11 @@ def test_CastArm(self) -> None:

# Check "b" parameter goes from float to int
self.assertTrue(isinstance(p_dict["b"], float))
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
new_arm = self.ss2.cast_arm(Arm(p_dict))
self.assertTrue(isinstance(new_arm.parameters["b"], int))

# Unknown parameter should be unchanged
p_dict["q"] = 40
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, Union[float, str]]`.
new_arm = self.ss2.cast_arm(Arm(p_dict))
self.assertTrue(isinstance(new_arm.parameters["q"], int))

Expand Down
3 changes: 0 additions & 3 deletions ax/metrics/tests/test_chemistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@ def test_ChemistryMetric(self) -> None:
params = dict(zip(param_names, param_values))
trial = get_trial()
trial._generator_run = GeneratorRun(
# pyre-fixme[6]: For 2nd argument expected `Dict[str,
# Union[None, bool, float, int, str]]` but got `Dict[str,
# Union[float, int, str]]`.
arms=[Arm(name="0_0", parameters=params)]
)
df = metric.fetch_trial_data(trial).unwrap().df
Expand Down
2 changes: 0 additions & 2 deletions ax/metrics/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def test_SklearnMetric(self) -> None:
params = {"max_depth": 2, "min_samples_split": 0.5}
trial = get_trial()
trial._generator_run = GeneratorRun(
# pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, float]`.
arms=[Arm(name="0_0", parameters=params)]
)
df = metric.fetch_trial_data(trial).unwrap().df
Expand Down
3 changes: 0 additions & 3 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,6 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen):
)
for i in range(5)
],
# pyre-fixme[6]: For 2nd param expected `List[float]` but got `List[int]`.
weights=[1] * 5,
)
exp = get_branin_experiment_with_multi_objective(with_status_quo=True)
Expand Down Expand Up @@ -724,8 +723,6 @@ def test_GenArms(self) -> None:
self.assertEqual(arms[0].parameters, p1)
self.assertIsNone(candidate_metadata)

# pyre-fixme[6]: For 2nd param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, int]`.
arm = Arm(name="1_1", parameters=p1)
arms_by_signature = {arm.signature: arm}
observation_features[0].metadata = {"some_key": "some_val_0"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
inspect.Parameter(
name="gs_gen_call_kwargs",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
# pyre-fixme[16]: `dict` has no attr. `__getitem__`
annotation=dict[str, Any],
),
inspect.Parameter(
Expand Down Expand Up @@ -656,8 +655,6 @@ def test_all_constructors_have_same_signature(self) -> None:
func_parameters["previous_node"], GenerationNode | None
)
self.assertEqual(func_parameters["next_node"], GenerationNode)
# pyre-ignore [16]: Undefined attribute [16]: `dict` has no attribute
# `__getitem__`.¸
self.assertEqual(func_parameters["gs_gen_call_kwargs"], dict[str, Any])
self.assertEqual(func_parameters["experiment"], Experiment)
self.assertEqual(method_signature, inspect.signature(constructor))
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_TorchModelBridge(
observations = recombine_observations(observation_features, observation_data)
ssd = SearchSpaceDigest(
feature_names=feature_names,
bounds=[(0, 1)] * 3, # pyre-ignore
bounds=[(0, 1)] * 3,
)

with mock.patch(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/transforms/task_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
transformed_parameters[p_name] = ChoiceParameter(
name=p_name,
parameter_type=ParameterType.INT,
values=list(range(len(p.values))), # pyre-ignore [6]
values=list(range(len(p.values))),
is_ordered=p.is_ordered,
is_task=True,
sort_values=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_num_choices(self) -> None:
new_search_space.parameters["d"],
ChoiceParameter(
"d",
values=list(range(1, 10)), # pyre-ignore
values=list(range(1, 10)),
is_ordered=True,
parameter_type=ParameterType.INT,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,6 @@ def test_RoundingWithConstrainedIntRanges(self) -> None:
RangeParameter("y", lower=1, upper=3, parameter_type=ParameterType.INT),
]
constrained_int_search_space = SearchSpace(
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
# `List[RangeParameter]`.
parameters=parameters,
parameter_constraints=[
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
Expand Down Expand Up @@ -244,8 +242,6 @@ def test_RoundingWithImpossiblyConstrainedIntRanges(self) -> None:
RangeParameter("y", lower=1, upper=5, parameter_type=ParameterType.INT),
]
constrained_int_search_space = SearchSpace(
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
# `List[RangeParameter]`.
parameters=parameters,
parameter_constraints=[
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
Expand Down
2 changes: 0 additions & 2 deletions ax/models/tests/test_randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ def test_RFModel(self) -> None:
datasets=datasets,
search_space_digest=SearchSpaceDigest(
feature_names=["x1", "x2"],
# pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float,
# int], Union[float, int]]]` but got `List[Tuple[int, int]]`.
bounds=[(0, 1)] * 2,
),
)
Expand Down
2 changes: 0 additions & 2 deletions ax/preview/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,6 @@ def predict(
"""
for parameters in points:
self._experiment.search_space.check_membership(
# pyre-fixme[6]: Core Ax allows users to specify TParameterization
# values as None but we do not allow this in the API.
parameterization=parameters,
raise_error=True,
check_all_parameters_present=True,
Expand Down
8 changes: 4 additions & 4 deletions ax/preview/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def test_get_best_parameterization(self) -> None:
)
self.assertTrue(
client._experiment.search_space.check_membership(
parameterization=parameters # pyre-ignore[6]
parameterization=parameters
)
)
self.assertEqual({*prediction.keys()}, {"foo"})
Expand All @@ -938,7 +938,7 @@ def test_get_best_parameterization(self) -> None:
)
self.assertTrue(
client._experiment.search_space.check_membership(
parameterization=parameters # pyre-fixme[6]
parameterization=parameters
)
)
self.assertEqual({*prediction.keys()}, {"foo"})
Expand Down Expand Up @@ -994,7 +994,7 @@ def test_get_pareto_frontier(self) -> None:
)
self.assertTrue(
client._experiment.search_space.check_membership(
parameterization=parameters # pyre-ignore[6]
parameterization=parameters
)
)
self.assertEqual({*prediction.keys()}, {"foo", "bar"})
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def test_get_pareto_frontier(self) -> None:
)
self.assertTrue(
client._experiment.search_space.check_membership(
parameterization=parameters # pyre-fixme[6]
parameterization=parameters
)
)
self.assertEqual({*prediction.keys()}, {"foo", "bar"})
Expand Down
2 changes: 0 additions & 2 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,8 +823,6 @@ def test_run_preattached_trials_only(self) -> None:
)
trial = scheduler.experiment.new_trial()
parameter_dict = {"x1": 5, "x2": 5}
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, int]`.
trial.add_arm(Arm(parameters=parameter_dict))

# check no new trials are run, when max_trials = 0
Expand Down
2 changes: 0 additions & 2 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,6 @@ def test_set_status_quo(self) -> None:
ax_client.set_status_quo(status_quo_params)
self.assertEqual(
ax_client.experiment.status_quo,
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool,
# float, int, str]]` but got `Dict[str, float]`.
Arm(parameters=status_quo_params, name="status_quo"),
)

Expand Down
Loading

0 comments on commit 6607582

Please sign in to comment.