From 46fef44a20cbc57ebff602a24c29764f293e4cb1 Mon Sep 17 00:00:00 2001 From: Ben Letham Date: Wed, 22 Nov 2023 07:18:11 -0800 Subject: [PATCH] Add label_dict to tile plots (#2007) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2007 Most of the plots now allow passing along a label_dict. The tile plots didn't have it, but now they do. This enables having a legible plot with long metric names. Reviewed By: Balandat, ItsMrLin Differential Revision: D51502544 fbshipit-source-id: 76f4d1de5a59fcb3cecf6329c20346f4f79951a3 --- ax/plot/scatter.py | 18 ++++++++++++++++-- ax/plot/tests/test_tile_fitted.py | 10 ++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index 11a0689cba5..ff559e43f8d 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -417,7 +417,11 @@ def plot_multiple_metrics( layout_offset_x = 0 rel = checked_cast_optional(bool, kwargs.get("rel")) if rel is not None: - warnings.warn("Use `rel_x` and `rel_y` instead of `rel`.", DeprecationWarning) + warnings.warn( + "Use `rel_x` and `rel_y` instead of `rel`.", + DeprecationWarning, + stacklevel=2, + ) rel_x = rel rel_y = rel traces = _multiple_metric_traces( @@ -1298,6 +1302,7 @@ def tile_fitted( fixed_features: Optional[ObservationFeatures] = None, data_selector: Optional[Callable[[Observation], bool]] = None, scalarized_metric_config: Optional[List[Dict[str, Any]]] = None, + label_dict: Optional[Dict[str, str]] = None, ) -> AxPlotConfig: """Tile version of fitted outcome plots. @@ -1318,19 +1323,24 @@ def tile_fitted( the name of the new scalarized metric, and the value is a dictionary mapping each metric to its weight. e.g. {"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}. + label_dict: A dictionary that maps the label to an alias to be used in the plot. """ metrics = metrics or list(model.metric_names) nrows = int(np.ceil(len(metrics) / 2)) ncols = min(len(metrics), 2) # make subplots (plot per row) + if label_dict is None: + subplot_titles = metrics + else: + subplot_titles = [label_dict.get(metric, metric) for metric in metrics] fig = subplots.make_subplots( rows=nrows, cols=ncols, print_grid=False, shared_xaxes=False, shared_yaxes=False, - subplot_titles=tuple(metrics), + subplot_titles=tuple(subplot_titles), horizontal_spacing=0.05, vertical_spacing=0.30 / nrows, ) @@ -1671,6 +1681,7 @@ def tile_observations( metrics: Optional[List[str]] = None, arm_names: Optional[List[str]] = None, arm_noun: str = "arm", + label_dict: Optional[Dict[str, str]] = None, ) -> AxPlotConfig: """ Tiled plot with all observed outcomes. @@ -1686,6 +1697,8 @@ def tile_observations( rel: Plot relative values, if experiment has status quo. metrics: Limit results to this set of metrics. arm_names: Limit results to this set of arms. + arm_noun: Noun to use instead of "arm". + label_dict: A dictionary that maps the label to an alias to be used in the plot. Returns: Plot config for the plot. """ @@ -1703,4 +1716,5 @@ def tile_observations( rel=rel and (experiment.status_quo is not None), metrics=metrics, arm_noun=arm_noun, + label_dict=label_dict, ) diff --git a/ax/plot/tests/test_tile_fitted.py b/ax/plot/tests/test_tile_fitted.py index a1e881081b3..fc3075ddec1 100644 --- a/ax/plot/tests/test_tile_fitted.py +++ b/ax/plot/tests/test_tile_fitted.py @@ -141,8 +141,18 @@ def test_TileObservations(self) -> None: ]: self.assertIn(key, config.data["layout"]) + self.assertEqual( + config.data["layout"]["annotations"][0]["text"], "ax_test_metric" + ) + # Data self.assertEqual(config.data["data"][0]["x"], ["0_1", "0_2"]) self.assertEqual(config.data["data"][0]["y"], [2.0, 2.25]) self.assertEqual(config.data["data"][0]["type"], "scatter") self.assertIn("Arm 0_1", config.data["data"][0]["text"][0]) + + label_dict = {"ax_test_metric": "mapped_name"} + config = tile_observations( + experiment=exp, arm_names=["0_1", "0_2"], rel=False, label_dict=label_dict + ) + self.assertEqual(config.data["layout"]["annotations"][0]["text"], "mapped_name")