diff --git a/ax/core/map_metric.py b/ax/core/map_metric.py index 4dc0a4b5ad6..a460b9de032 100644 --- a/ax/core/map_metric.py +++ b/ax/core/map_metric.py @@ -30,4 +30,4 @@ class MapMetric(Metric): """ data_constructor: type[MapData] = MapData - map_key_info: MapKeyInfo[float] = MapKeyInfo(key="steps", default_value=0.0) + map_key_info: MapKeyInfo[float] = MapKeyInfo(key="step", default_value=0.0) diff --git a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py index ce598d457b7..ee94c237af2 100644 --- a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py +++ b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py @@ -93,7 +93,7 @@ def test_Init(self) -> None: "parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}} }, ) - self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) + self.assertDictEqual(t.parameters, {"step": {"log_scale": False}}) self.assertEqual(len(t._parameter_list), 1)