From 634374aae25d5d660d982f82c450071ce9f4c42e Mon Sep 17 00:00:00 2001 From: "Jacob E. Jensen" Date: Mon, 26 Aug 2019 10:56:18 -0700 Subject: [PATCH] Replace slice.js and interact_slice.js with python and generic_plotly.js Reviewed By: lena-kashtelyan Differential Revision: D16965821 fbshipit-source-id: 9e8e0f75e6fb789cddbbd1e15db2637e3fb4cf73 --- ax/plot/helper.py | 147 ++++++++++++++++++++++++++++++++++ ax/plot/js/interact_slice.js | 105 ------------------------ ax/plot/js/slice.js | 61 -------------- ax/plot/render.py | 6 +- ax/plot/slice.py | 151 ++++++++++++++++++++++++++++++++++- 5 files changed, 297 insertions(+), 173 deletions(-) delete mode 100644 ax/plot/js/interact_slice.js delete mode 100644 ax/plot/js/slice.js diff --git a/ax/plot/helper.py b/ax/plot/helper.py index cc0857ea87a..99eb3c97b9e 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -647,3 +647,150 @@ def relativize_data( def rgb(arr: List[int]) -> str: return "rgb({},{},{})".format(*arr) + + +def slice_config_to_trace( + arm_data, + arm_name_to_parameters, + f, + fit_data, + grid, + metric, + param, + rel, + setx, + sd, + is_log, + visible, +): + # format data + res = relativize_data(f, sd, rel, arm_data, metric) + f_final = res[0] + sd_final = res[1] + + # get data for standard deviation fill plot + sd_upper = [] + sd_lower = [] + for i in range(len(sd)): + sd_upper.append(f_final[i] + 2 * sd_final[i]) + sd_lower.append(f_final[i] - 2 * sd_final[i]) + grid_rev = list(reversed(grid)) + sd_lower_rev = list(reversed(sd_lower)) + sd_x = grid + grid_rev + sd_y = sd_upper + sd_lower_rev + + # get data for observed arms and error bars + arm_x = [] + arm_y = [] + arm_sem = [] + for row in fit_data: + parameters = arm_name_to_parameters[row["arm_name"]] + plot = True + for p in setx.keys(): + if p != param and parameters[p] != setx[p]: + plot = False + if plot: + arm_x.append(parameters[param]) + arm_y.append(row["mean"]) + arm_sem.append(row["sem"]) + + arm_res = relativize_data(arm_y, arm_sem, rel, arm_data, metric) + arm_y_final = arm_res[0] + arm_sem_final = [x * 2 for x in arm_res[1]] + + # create traces + f_trace = { + "x": grid, + "y": f_final, + "showlegend": False, + "hoverinfo": "x+y", + "line": {"color": "rgba(128, 177, 211, 1)"}, + "visible": visible, + } + + arms_trace = { + "x": arm_x, + "y": arm_y_final, + "mode": "markers", + "error_y": { + "type": "data", + "array": arm_sem_final, + "visible": True, + "color": "black", + }, + "line": {"color": "black"}, + "showlegend": False, + "hoverinfo": "x+y", + "visible": visible, + } + + sd_trace = { + "x": sd_x, + "y": sd_y, + "fill": "toself", + "fillcolor": "rgba(128, 177, 211, 0.2)", + "line": {"color": "transparent"}, + "showlegend": False, + "hoverinfo": "none", + "visible": visible, + } + + traces = [sd_trace, f_trace, arms_trace] + + # iterate over out-of-sample arms + for i, generator_run_name in enumerate(arm_data["out_of_sample"].keys()): + ax = [] + ay = [] + asem = [] + atext = [] + + for arm_name in arm_data["out_of_sample"][generator_run_name].keys(): + parameters = arm_data["out_of_sample"][generator_run_name][arm_name][ + "parameters" + ] + plot = True + for p in setx.keys(): + if p != param and parameters[p] != setx[p]: + plot = False + if plot: + ax.append(parameters[param]) + ay.append( + arm_data["out_of_sample"][generator_run_name][arm_name]["y_hat"][ + metric + ] + ) + asem.append( + arm_data["out_of_sample"][generator_run_name][arm_name]["se_hat"][ + metric + ] + ) + atext.append("Candidate " + arm_name + "") + + out_of_sample_arm_res = relativize_data(ay, asem, rel, arm_data, metric) + ay_final = out_of_sample_arm_res[0] + asem_final = [x * 2 for x in out_of_sample_arm_res[1]] + + traces.append( + { + "hoverinfo": "text", + "legendgroup": generator_run_name, + "marker": {"color": "black", "symbol": i + 1, "opacity": 0.5}, + "mode": "markers", + "error_y": { + "type": "data", + "array": asem_final, + "visible": True, + "color": "black", + }, + "name": generator_run_name, + "text": atext, + "type": "scatter", + "xaxis": "x", + "x": ax, + "yaxis": "y", + "y": ay_final, + "visible": visible, + } + ) + + return traces diff --git a/ax/plot/js/interact_slice.js b/ax/plot/js/interact_slice.js deleted file mode 100644 index 6d208c2ad3c..00000000000 --- a/ax/plot/js/interact_slice.js +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. - */ - -const arm_data = {{arm_data}}; -const arm_name_to_parameters = {{arm_name_to_parameters}}; -const f = {{f}}; -const fit_data = {{fit_data}}; -const grid = {{grid}}; -const metrics = {{metrics}} -const param = {{param}}; -const rel = {{rel}}; -const setx = {{setx}}; -const sd = {{sd}}; -const is_log = {{is_log}}; - -traces = []; -metric_cnt = metrics.length; - -for (let i = 0; i < metric_cnt; i++) { - cur_visible = i == 0; - metric = metrics[i]; - traces = traces.concat( - slice_config_to_trace( - arm_data[metric], - arm_name_to_parameters[metric], - f[metric], - fit_data[metric], - grid, - metric, - param, - rel, - setx, - sd[metric], - is_log[metric], - cur_visible, - ), - ); -} - -// layout -const xrange = axis_range(grid, is_log[metrics[0]]); -const xtype = is_log[metrics[0]] ? 'log' : 'linear'; - -let buttons = []; -for (let i = 0; i < metric_cnt; i++) { - metric = metrics[i]; - let trace_cnt = 3 + Object.keys(arm_data[metric]['out_of_sample']).length * 2; - visible = new Array(metric_cnt * trace_cnt); - visible.fill(false).fill(true, i * trace_cnt, (i + 1) * trace_cnt); - buttons.push({ - method: 'update', - args: [{visible: visible}, {'yaxis.title': metric}], - label: metric, - }); -} - -const layout = { - title: 'Predictions for a 1-d slice of the parameter space', - annotations: [ - { - showarrow: false, - text: 'Choose metric:', - x: 0.225, - xanchor: 'center', - xref: 'paper', - y: 1.005, - yanchor: 'bottom', - yref: 'paper', - }, - ], - updatemenus: [ - { - y: 1.1, - x: 0.5, - yanchor: 'top', - buttons: buttons, - }, - ], - hovermode: 'closest', - xaxis: { - anchor: 'y', - autorange: false, - exponentformat: 'e', - range: xrange, - tickfont: {size: 11}, - tickmode: 'auto', - title: param, - type: xtype, - }, - yaxis: { - autorange: true, - anchor: 'x', - tickfont: {size: 11}, - tickmode: 'auto', - title: metrics[0], - }, -}; - -Plotly.newPlot( - {{id}}, - traces, - layout, - {showLink: false}, -); diff --git a/ax/plot/js/slice.js b/ax/plot/js/slice.js deleted file mode 100644 index e56b62b176e..00000000000 --- a/ax/plot/js/slice.js +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. - */ - -const arm_data = {{arm_data}}; -const arm_name_to_parameters = {{arm_name_to_parameters}}; -const f = {{f}}; -const fit_data = {{fit_data}}; -const grid = {{grid}}; -const metric = {{metric}}; -const param = {{param}}; -const rel = {{rel}}; -const setx = {{setx}}; -const sd = {{sd}}; -const is_log = {{is_log}}; - -traces = slice_config_to_trace( - arm_data, - arm_name_to_parameters, - f, - fit_data, - grid, - metric, - param, - rel, - setx, - sd, - is_log, - true, -); - -// layout -const xrange = axis_range(grid, is_log); -const xtype = is_log ? 'log' : 'linear'; - -layout = { - hovermode: 'closest', - xaxis: { - anchor: 'y', - autorange: false, - exponentformat: 'e', - range: xrange, - tickfont: {size: 11}, - tickmode: 'auto', - title: param, - type: xtype, - }, - yaxis: { - anchor: 'x', - tickfont: {size: 11}, - tickmode: 'auto', - title: metric, - }, -}; - -Plotly.newPlot( - {{id}}, - traces, - layout, - {showLink: false}, -); diff --git a/ax/plot/render.py b/ax/plot/render.py index 9c879cdbdfd..b1e41ce5768 100644 --- a/ax/plot/render.py +++ b/ax/plot/render.py @@ -32,11 +32,7 @@ class _AxPlotJSResources(enum.Enum): # JS-based plots that are supported in Ax should be registered here -Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = { - AxPlotTypes.GENERIC: "generic_plotly.js", - AxPlotTypes.SLICE: "slice.js", - AxPlotTypes.INTERACT_SLICE: "interact_slice.js", -} +Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = {AxPlotTypes.GENERIC: "generic_plotly.js"} def _load_js_resource(resource_type: _AxPlotJSResources) -> str: diff --git a/ax/plot/slice.py b/ax/plot/slice.py index 46d35499bb8..33f757c61a0 100644 --- a/ax/plot/slice.py +++ b/ax/plot/slice.py @@ -10,11 +10,14 @@ from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData from ax.plot.helper import ( TNullableGeneratorRunsDict, + axis_range, get_fixed_values, get_grid_for_parameter, get_plot_data, get_range_parameter, + slice_config_to_trace, ) +from plotly import graph_objs as go # type aliases @@ -158,7 +161,61 @@ def plot_slice( "sd": sd_plt, "is_log": ls, } - return AxPlotConfig(config, plot_type=AxPlotTypes.SLICE) + config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data + + arm_data = config["arm_data"] + arm_name_to_parameters = config["arm_name_to_parameters"] + f = config["f"] + fit_data = config["fit_data"] + grid = config["grid"] + metric = config["metric"] + param = config["param"] + rel = config["rel"] + setx = config["setx"] + sd = config["sd"] + is_log = config["is_log"] + + traces = slice_config_to_trace( + arm_data, + arm_name_to_parameters, + f, + fit_data, + grid, + metric, + param, + rel, + setx, + sd, + is_log, + True, + ) + + # layout + xrange = axis_range(grid, is_log) + xtype = "log" if is_log else "linear" + + layout = { + "hovermode": "closest", + "xaxis": { + "anchor": "y", + "autorange": False, + "exponentformat": "e", + "range": xrange, + "tickfont": {"size": 11}, + "tickmode": "auto", + "title": param, + "type": xtype, + }, + "yaxis": { + "anchor": "x", + "tickfont": {"size": 11}, + "tickmode": "auto", + "title": metric, + }, + } + + fig = go.Figure(data=traces, layout=layout) # pyre-ignore[16] + return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC) def interact_slice( @@ -253,4 +310,94 @@ def interact_slice( "sd": sd_plt_dict, "is_log": is_log_dict, } - return AxPlotConfig(config, plot_type=AxPlotTypes.INTERACT_SLICE) + config = AxPlotConfig(config, plot_type=AxPlotTypes.GENERIC).data + + arm_data = config["arm_data"] + arm_name_to_parameters = config["arm_name_to_parameters"] + f = config["f"] + fit_data = config["fit_data"] + grid = config["grid"] + metrics = config["metrics"] + param = config["param"] + rel = config["rel"] + setx = config["setx"] + sd = config["sd"] + is_log = config["is_log"] + + traces = [] + + for i, metric in enumerate(metrics): + cur_visible = i == 0 + metric = metrics[i] + traces.extend( + slice_config_to_trace( + arm_data[metric], + arm_name_to_parameters[metric], + f[metric], + fit_data[metric], + grid, + metric, + param, + rel, + setx, + sd[metric], + is_log[metric], + cur_visible, + ) + ) + + # layout + xrange = axis_range(grid, is_log[metrics[0]]) + xtype = "log" if is_log[metrics[0]] else "linear" + + buttons = [] + for i, metric in enumerate(metrics): + trace_cnt = 3 + len(arm_data[metric]["out_of_sample"].keys()) * 2 + visible = [False] * (len(metrics) * trace_cnt) + for j in range(i * trace_cnt, (i + 1) * trace_cnt): + visible[j] = True + buttons.append( + { + "method": "update", + "args": [{"visible": visible}, {"yaxis.title": metric}], + "label": metric, + } + ) + + layout = { + "title": "Predictions for a 1-d slice of the parameter space", + "annotations": [ + { + "showarrow": False, + "text": "Choose metric:", + "x": 0.225, + "xanchor": "center", + "xref": "paper", + "y": 1.005, + "yanchor": "bottom", + "yref": "paper", + } + ], + "updatemenus": [{"y": 1.1, "x": 0.5, "yanchor": "top", "buttons": buttons}], + "hovermode": "closest", + "xaxis": { + "anchor": "y", + "autorange": False, + "exponentformat": "e", + "range": xrange, + "tickfont": {"size": 11}, + "tickmode": "auto", + "title": param, + "type": xtype, + }, + "yaxis": { + "anchor": "x", + "autorange": True, + "tickfont": {"size": 11}, + "tickmode": "auto", + "title": metrics[0], + }, + } + + fig = go.Figure(data=traces, layout=layout) # pyre-ignore[16] + return AxPlotConfig(data=fig, plot_type=AxPlotTypes.GENERIC)