Skip to content

Commit

Permalink
Add label_dict to tile plots
Browse files Browse the repository at this point in the history
Summary: Most of the plots now allow passing along a label_dict. The tile plots didn't have it, but now they do. This enables having a legible plot with long metric names.

Differential Revision: D51502544
  • Loading branch information
bletham authored and facebook-github-bot committed Nov 21, 2023
1 parent 1adaba5 commit 157063d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
18 changes: 16 additions & 2 deletions ax/plot/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,11 @@ def plot_multiple_metrics(
layout_offset_x = 0
rel = checked_cast_optional(bool, kwargs.get("rel"))
if rel is not None:
warnings.warn("Use `rel_x` and `rel_y` instead of `rel`.", DeprecationWarning)
warnings.warn(
"Use `rel_x` and `rel_y` instead of `rel`.",
DeprecationWarning,
stacklevel=2,
)
rel_x = rel
rel_y = rel
traces = _multiple_metric_traces(
Expand Down Expand Up @@ -1298,6 +1302,7 @@ def tile_fitted(
fixed_features: Optional[ObservationFeatures] = None,
data_selector: Optional[Callable[[Observation], bool]] = None,
scalarized_metric_config: Optional[List[Dict[str, Any]]] = None,
label_dict: Optional[Dict[str, str]] = None,
) -> AxPlotConfig:
"""Tile version of fitted outcome plots.
Expand All @@ -1318,19 +1323,24 @@ def tile_fitted(
the name of the new scalarized metric, and the value is a dictionary mapping
each metric to its weight. e.g.
{"name": "metric1:agg", "weight": {"metric1_c1": 0.5, "metric1_c2": 0.5}}.
label_dict: A dictionary that maps the label to an alias to be used in the plot.
"""
metrics = metrics or list(model.metric_names)
nrows = int(np.ceil(len(metrics) / 2))
ncols = min(len(metrics), 2)

# make subplots (plot per row)
if label_dict is None:
subplot_titles = metrics
else:
subplot_titles = [label_dict.get(metric, metric) for metric in metrics]
fig = subplots.make_subplots(
rows=nrows,
cols=ncols,
print_grid=False,
shared_xaxes=False,
shared_yaxes=False,
subplot_titles=tuple(metrics),
subplot_titles=tuple(subplot_titles),
horizontal_spacing=0.05,
vertical_spacing=0.30 / nrows,
)
Expand Down Expand Up @@ -1671,6 +1681,7 @@ def tile_observations(
metrics: Optional[List[str]] = None,
arm_names: Optional[List[str]] = None,
arm_noun: str = "arm",
label_dict: Optional[Dict[str, str]] = None,
) -> AxPlotConfig:
"""
Tiled plot with all observed outcomes.
Expand All @@ -1686,6 +1697,8 @@ def tile_observations(
rel: Plot relative values, if experiment has status quo.
metrics: Limit results to this set of metrics.
arm_names: Limit results to this set of arms.
arm_noun: Noun to use instead of "arm".
label_dict: A dictionary that maps the label to an alias to be used in the plot.
Returns: Plot config for the plot.
"""
Expand All @@ -1703,4 +1716,5 @@ def tile_observations(
rel=rel and (experiment.status_quo is not None),
metrics=metrics,
arm_noun=arm_noun,
label_dict=label_dict,
)
10 changes: 10 additions & 0 deletions ax/plot/tests/test_tile_fitted.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,18 @@ def test_TileObservations(self) -> None:
]:
self.assertIn(key, config.data["layout"])

self.assertEqual(
config.data["layout"]["annotations"][0]["text"], "ax_test_metric"
)

# Data
self.assertEqual(config.data["data"][0]["x"], ["0_1", "0_2"])
self.assertEqual(config.data["data"][0]["y"], [2.0, 2.25])
self.assertEqual(config.data["data"][0]["type"], "scatter")
self.assertIn("Arm 0_1", config.data["data"][0]["text"][0])

label_dict = {"ax_test_metric": "mapped_name"}
config = tile_observations(
experiment=exp, arm_names=["0_1", "0_2"], rel=False, label_dict=label_dict
)
self.assertEqual(config.data["layout"]["annotations"][0]["text"], "mapped_name")

0 comments on commit 157063d

Please sign in to comment.