Skip to content

Commit

Permalink
Replace slice.js and interact_slice.js with python and generic_plotly.js
Browse files Browse the repository at this point in the history
Reviewed By: lena-kashtelyan

Differential Revision: D16965821

fbshipit-source-id: 9e8e0f75e6fb789cddbbd1e15db2637e3fb4cf73
  • Loading branch information
2timesjay authored and facebook-github-bot committed Aug 26, 2019
1 parent bf84a8f commit 634374a
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 173 deletions.
147 changes: 147 additions & 0 deletions ax/plot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,3 +647,150 @@ def relativize_data(

def rgb(arr: List[int]) -> str:
return "rgb({},{},{})".format(*arr)


def slice_config_to_trace(
arm_data,
arm_name_to_parameters,
f,
fit_data,
grid,
metric,
param,
rel,
setx,
sd,
is_log,
visible,
):
# format data
res = relativize_data(f, sd, rel, arm_data, metric)
f_final = res[0]
sd_final = res[1]

# get data for standard deviation fill plot
sd_upper = []
sd_lower = []
for i in range(len(sd)):
sd_upper.append(f_final[i] + 2 * sd_final[i])
sd_lower.append(f_final[i] - 2 * sd_final[i])
grid_rev = list(reversed(grid))
sd_lower_rev = list(reversed(sd_lower))
sd_x = grid + grid_rev
sd_y = sd_upper + sd_lower_rev

# get data for observed arms and error bars
arm_x = []
arm_y = []
arm_sem = []
for row in fit_data:
parameters = arm_name_to_parameters[row["arm_name"]]
plot = True
for p in setx.keys():
if p != param and parameters[p] != setx[p]:
plot = False
if plot:
arm_x.append(parameters[param])
arm_y.append(row["mean"])
arm_sem.append(row["sem"])

arm_res = relativize_data(arm_y, arm_sem, rel, arm_data, metric)
arm_y_final = arm_res[0]
arm_sem_final = [x * 2 for x in arm_res[1]]

# create traces
f_trace = {
"x": grid,
"y": f_final,
"showlegend": False,
"hoverinfo": "x+y",
"line": {"color": "rgba(128, 177, 211, 1)"},
"visible": visible,
}

arms_trace = {
"x": arm_x,
"y": arm_y_final,
"mode": "markers",
"error_y": {
"type": "data",
"array": arm_sem_final,
"visible": True,
"color": "black",
},
"line": {"color": "black"},
"showlegend": False,
"hoverinfo": "x+y",
"visible": visible,
}

sd_trace = {
"x": sd_x,
"y": sd_y,
"fill": "toself",
"fillcolor": "rgba(128, 177, 211, 0.2)",
"line": {"color": "transparent"},
"showlegend": False,
"hoverinfo": "none",
"visible": visible,
}

traces = [sd_trace, f_trace, arms_trace]

# iterate over out-of-sample arms
for i, generator_run_name in enumerate(arm_data["out_of_sample"].keys()):
ax = []
ay = []
asem = []
atext = []

for arm_name in arm_data["out_of_sample"][generator_run_name].keys():
parameters = arm_data["out_of_sample"][generator_run_name][arm_name][
"parameters"
]
plot = True
for p in setx.keys():
if p != param and parameters[p] != setx[p]:
plot = False
if plot:
ax.append(parameters[param])
ay.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["y_hat"][
metric
]
)
asem.append(
arm_data["out_of_sample"][generator_run_name][arm_name]["se_hat"][
metric
]
)
atext.append("<em>Candidate " + arm_name + "</em>")

out_of_sample_arm_res = relativize_data(ay, asem, rel, arm_data, metric)
ay_final = out_of_sample_arm_res[0]
asem_final = [x * 2 for x in out_of_sample_arm_res[1]]

traces.append(
{
"hoverinfo": "text",
"legendgroup": generator_run_name,
"marker": {"color": "black", "symbol": i + 1, "opacity": 0.5},
"mode": "markers",
"error_y": {
"type": "data",
"array": asem_final,
"visible": True,
"color": "black",
},
"name": generator_run_name,
"text": atext,
"type": "scatter",
"xaxis": "x",
"x": ax,
"yaxis": "y",
"y": ay_final,
"visible": visible,
}
)

return traces
105 changes: 0 additions & 105 deletions ax/plot/js/interact_slice.js

This file was deleted.

61 changes: 0 additions & 61 deletions ax/plot/js/slice.js

This file was deleted.

6 changes: 1 addition & 5 deletions ax/plot/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@ class _AxPlotJSResources(enum.Enum):


# JS-based plots that are supported in Ax should be registered here
Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = {
AxPlotTypes.GENERIC: "generic_plotly.js",
AxPlotTypes.SLICE: "slice.js",
AxPlotTypes.INTERACT_SLICE: "interact_slice.js",
}
Ax_PLOT_REGISTRY: Dict[enum.Enum, str] = {AxPlotTypes.GENERIC: "generic_plotly.js"}


def _load_js_resource(resource_type: _AxPlotJSResources) -> str:
Expand Down
Loading

0 comments on commit 634374a

Please sign in to comment.