diff --git a/DESCRIPTION b/DESCRIPTION index 86599b2..aa6c988 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: modelStudio Title: Interactive Studio for Explanatory Model Analysis -Version: 3.0.0.9000 +Version: 3.1.0 Authors@R: c(person("Hubert", "Baniecki", role = c("aut", "cre"), email = "hbaniecki@gmail.com", diff --git a/NEWS.md b/NEWS.md index 4305d29..165e954 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ -# modelStudio (development) +# modelStudio 3.1.0 +* changed y-axis variable labels in `SV` to the same as in `BD` * added new parameter to `modelStudio()`: `max_features_fi = max_features`, which allows displaying a distinctive number of features in `FI` plot (other than in `BD` and `SV`) +* added new options to `ms_options()`: `**_axis_title`, which allow changing plot-specific axis title (default varies) # modelStudio 3.0.0 * **BREAKING CHANGES**: diff --git a/R/modelStudio.R b/R/modelStudio.R index c82aa55..797bc07 100644 --- a/R/modelStudio.R +++ b/R/modelStudio.R @@ -350,9 +350,8 @@ modelStudio.explainer <- function(explainer, variable_splits_type=variable_splits_type), "ingredients::accumulated_dependence (categorical)", show_info, pb, N/30) } - - fi_data <- prepare_feature_importance(fi, max_features_fi, options$show_boxplot, - attr(loss_function, "loss_name"), ...) + + fi_data <- prepare_feature_importance(fi, max_features_fi, options$show_boxplot, ...) pd_data <- prepare_partial_dependence(pd_n, pd_c, variables = variable_names) ad_data <- prepare_accumulated_dependence(ad_n, ad_c, variables = variable_names) mp_ret <- calculate( @@ -495,6 +494,8 @@ modelStudio.explainer <- function(explainer, if (is.null(options$ms_title)) options$ms_title <- paste0("Interactive Studio for ", label, " Model") if (!is.null(options$ms_subtitle)) options$ms_margin_top <- options$ms_margin_top + 40 if (is.null(options$margin_left)) options$margin_left <- max(105, 7*max(nchar(variable_names))) + if (is.null(options$fi_axis_title)) options$fi_axis_title <- + ifelse(is.null(attr(loss_function, "loss_name")), "drop-out loss", attr(loss_function, "loss_name")) options <- c(list(time = time, model_name = label, diff --git a/R/ms_options.R b/R/ms_options.R index 1644af5..9880f94 100644 --- a/R/ms_options.R +++ b/R/ms_options.R @@ -33,18 +33,19 @@ #' \item{default_color}{\code{#371ea3} for Break Down bar and highlighted line.} #' } #' } -#' \subsection{Plot specific options:}{ +#' \subsection{Plot-specific options:}{ #' \code{**} is a two letter code unique to each plot, might be #' one of \code{[bd,sv,cp,fi,pd,ad,rv,fd,tv,at]}.\cr #' #' \describe{ -#' \item{**_title}{Plot specific title. Default varies.} -#' \item{**_subtitle}{Plot specific subtitle. Default is \code{subtitle}.} -#' \item{**_bar_width}{Plot specific width of bars. Default is \code{bar_width}, +#' \item{**_title}{Plot-specific title. Default varies.} +#' \item{**_subtitle}{Plot-specific subtitle. Default is \code{subtitle}.} +#' \item{**_axis_title}{Plot-specific axis title. Default varies.} +#' \item{**_bar_width}{Plot-specific width of bars. Default is \code{bar_width}, #' ignored when \code{scale_plot = TRUE}.} -#' \item{**_line_size}{\code{line_size} Plot specific width of lines. Default is \code{line_size}.} -#' \item{**_point_size}{Plot specific point radius. Default is \code{point_size}.} -#' \item{**_*_color}{Plot specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} +#' \item{**_line_size}{Plot-specific width of lines. Default is \code{line_size}.} +#' \item{**_point_size}{Plot-specific point radius. Default is \code{point_size}.} +#' \item{**_*_color}{Plot-specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} #' } #' } #' @@ -131,18 +132,21 @@ ms_options <- function(...) { default_color = "#371ea3", bd_title = "Break Down", bd_subtitle = NULL, + bd_axis_title = "contribution", bd_bar_width = NULL, bd_positive_color = NULL, bd_negative_color = NULL, bd_default_color = NULL, sv_title = "Shapley Values", sv_subtitle = NULL, + sv_axis_title = "contribution", sv_bar_width = NULL, sv_positive_color = NULL, sv_negative_color = NULL, sv_default_color = NULL, cp_title = "Ceteris Paribus", cp_subtitle = NULL, + cp_axis_title = "prediction", cp_bar_width = NULL, cp_line_size = NULL, cp_point_size = 3, @@ -151,34 +155,41 @@ ms_options <- function(...) { cp_point_color = "#371ea3", fi_title = "Feature Importance", fi_subtitle = NULL, + fi_axis_title = NULL, fi_bar_width = NULL, fi_bar_color = NULL, pd_title = "Partial Dependence", pd_subtitle = NULL, + pd_axis_title = "average prediction", pd_bar_width = NULL, pd_line_size = NULL, pd_bar_color = NULL, pd_line_color = NULL, ad_title = "Accumulated Dependence", ad_subtitle = NULL, + ad_axis_title = "accumulated prediction", ad_bar_width = NULL, ad_line_size = NULL, ad_bar_color = NULL, ad_line_color = NULL, rv_title = "Residuals vs Feature", rv_subtitle = NULL, + rv_axis_title = "residuals", rv_point_size = NULL, rv_point_color = NULL, fd_title = "Feature Distribution", fd_subtitle = NULL, + fd_axis_title = "count", fd_bar_width = NULL, fd_bar_color = NULL, tv_title = "Target vs Feature", tv_subtitle = NULL, + tv_axis_title = "target", tv_point_size = NULL, tv_point_color = NULL, at_title = "Average Target vs Feature", at_subtitle = NULL, + at_axis_title = "average target", at_bar_width = NULL, at_line_size = NULL, at_point_size = 3, diff --git a/R/prepare.R b/R/prepare.R index eb7da1b..1669f30 100644 --- a/R/prepare.R +++ b/R/prepare.R @@ -286,7 +286,6 @@ prepare_ceteris_paribus <- function(x, variables = NULL) { } prepare_feature_importance <- function(x, max_features = 10, show_boxplot = TRUE, - x_title = NULL, margin = 0.15, digits = 3, rounding_function = round, ...) { ### Return the object for the FeatureImportance plot ### @@ -357,7 +356,6 @@ prepare_feature_importance <- function(x, max_features = 10, show_boxplot = TRUE ret$x <- new_x ret$m <- m ret$x_min_max <- min_max - ret$x_title <- ifelse(is.null(x_title), "drop-out loss", x_title) ret$desc <- data.frame(type = "desc", text = gsub("\n","
", desc)) diff --git a/inst/d3js/generatePlots.js b/inst/d3js/generatePlots.js index 580c767..20bbef3 100644 --- a/inst/d3js/generatePlots.js +++ b/inst/d3js/generatePlots.js @@ -137,7 +137,7 @@ function breakDown() { (margin.top + bdPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("contribution"); + .text(bdAxisTitle); var xAxis = d3.axisBottom(x) .ticks(5) @@ -375,7 +375,7 @@ function shapleyValues() { (margin.top + svPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("contribution"); + .text(svAxisTitle); var xAxis = d3.axisBottom(x) .ticks(5) @@ -390,7 +390,7 @@ function shapleyValues() { var y = d3.scaleBand() .rangeRound([margin.top - additionalHeight, margin.top + svPlotHeight]) .padding(0.33) - .domain(bData.map(d => d.variable_name)); + .domain(bData.map(d => d.variable)); var xGrid = SV.append("g") .attr("class", "grid") @@ -464,7 +464,7 @@ function shapleyValues() { .attr("fill-opacity", d => x(d.barSupport) - x(d.barStart) < 1.5 ? 0 : 1) //invisible bar for clicking purpose .attr("x", d => d.contribution > 0 ? x(d.barStart) : x(d.barSupport)) - .attr("y", d => y(d.variable_name)) + .attr("y", d => y(d.variable)) .attr("height", y.bandwidth()) .on('mouseover', tooltip.show) .on('mouseout', tooltip.hide) @@ -492,7 +492,7 @@ function shapleyValues() { .attr("class", "axisLabel") .attr("x", d => d.contribution > 0 ? x(d.barStart) - 5 : x(d.barSupport) + 5) - .attr("y", d => y(d.variable_name) + y.bandwidth()/2) + .attr("y", d => y(d.variable) + y.bandwidth()/2) .attr("dy", "0.4em") .attr("text-anchor", d => d.sign == "1" ? "end" : "start") .transition() @@ -509,15 +509,15 @@ function shapleyValues() { lines.append("line") .attr("class", "interceptLine") .attr("x1", d => d.contribution < 0 ? x(d.barSupport) : x(d.barStart)) - .attr("y1", d => y(d.variable_name)) + .attr("y1", d => y(d.variable)) .attr("x2", d => d.contribution < 0 ? x(d.barSupport) : x(d.barStart)) - .attr("y2", d => y(d.variable_name)) + .attr("y2", d => y(d.variable)) .transition() .duration(TIME) .delay((d,i) => (i+1) * TIME) .attr("y2", (d,i) => i == svBarCount - 1 - ? y(d.variable_name) + y.bandwidth() - : y(d.variable_name) + y.bandwidth()*2.5); + ? y(d.variable) + y.bandwidth() + : y(d.variable) + y.bandwidth()*2.5); // boxplots @@ -527,8 +527,8 @@ function shapleyValues() { .attr("class", "interceptLine") .attr("x1", d => d.contribution < 0 ? x(d.max) : x(d.min)) .attr("x2", d => d.contribution < 0 ? x(d.max) : x(d.min)) - .attr("y1", d => y(d.variable_name) + y.bandwidth()/2) - .attr("y2", d => y(d.variable_name) + y.bandwidth()/2) + .attr("y1", d => y(d.variable) + y.bandwidth()/2) + .attr("y2", d => y(d.variable) + y.bandwidth()/2) .transition() .duration(TIME) .delay((d,i) => i * TIME) @@ -538,7 +538,7 @@ function shapleyValues() { // rectangle for the main box bars.append("rect") .attr("x", d => d.contribution < 0 ? x(d.q3) : x(d.q1)) - .attr("y", d => y(d.variable_name) + y.bandwidth()/3) + .attr("y", d => y(d.variable) + y.bandwidth()/3) .attr("height", y.bandwidth()/3) .style("fill", "#371ea3") .transition() @@ -550,8 +550,8 @@ function shapleyValues() { // // show the median // bars.append("line") // .attr("class", "interceptLine") - // .attr("y1", d => y(d.variable_name) + y.bandwidth()/3) - // .attr("y2", d => y(d.variable_name) + 2*y.bandwidth()/3) + // .attr("y1", d => y(d.variable) + y.bandwidth()/3) + // .attr("y2", d => y(d.variable) + 2*y.bandwidth()/3) // .attr("x1", d => x(d.median)) // .attr("x2", d => x(d.median)) // .style("stroke", "#ceced9") @@ -626,7 +626,6 @@ function featureImportance() { var fiBarCount = fiData.m, bData = fiData.x, xMinMax = fiData.x_min_max, - xTitle = fiData.x_title, desc = fiData.desc; var fiPlotHeight = SCALE_PLOT ? h : fiBarCount*fiBarWidth + (fiBarCount+1)*fiBarWidth/2, @@ -642,7 +641,7 @@ function featureImportance() { (margin.top + fiPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text(xTitle); + .text(fiAxisTitle); var xAxis = d3.axisBottom(x) .ticks(5) @@ -1165,7 +1164,7 @@ function cpNumericalPlot(variableName, lData, mData, yMinMax, pData, desc) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + cpPlotHeight/2)) .attr("text-anchor", "middle") - .text("prediction"); + .text(cpAxisTitle); var description = CP.append("g") .attr("transform", "translate(" + @@ -1321,8 +1320,16 @@ function cpCategoricalPlot(variableName, bData, yMinMax, lData, desc) { (margin.top + cpPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("prediction"); + .text(cpAxisTitle); +/* CP.append("text") + .attr("class", "axisTitle") + .attr("y", margin.top - additionalHeight) + .attr("x", margin.small) + .attr("text-anchor", "start") + .attr("dominant-baseline", "hanging") + .text(variableName + " = " + lData[0][variableName]);*/ + var description = CP.append("g") .attr("transform", "translate(" + (margin.left + cpPlotWidth - 4*margin.big - margin.small) @@ -1496,7 +1503,7 @@ function pdNumericalPlot(variableName, lData, mData, yMinMax, yMean, desc) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + pdPlotHeight/2)) .attr("text-anchor", "middle") - .text("average prediction"); + .text(pdAxisTitle); var description = PD.append("g") .attr("transform", "translate(" + @@ -1652,7 +1659,7 @@ function pdCategoricalPlot(variableName, bData, yMinMax, yMean, desc) { (margin.top + pdPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("average prediction"); + .text(pdAxisTitle); var description = PD.append("g") .attr("transform", "translate(" + @@ -1827,7 +1834,7 @@ function adNumericalPlot(variableName, lData, mData, yMinMax, yMean, desc) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + adPlotHeight/2)) .attr("text-anchor", "middle") - .text("accumulated prediction"); + .text(adAxisTitle); // var description = AD.append("g") // .attr("transform", "translate(" + @@ -1981,7 +1988,7 @@ function adCategoricalPlot(variableName, bData, yMinMax, yMean, desc) { (margin.top + adPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("accumulated prediction"); + .text(adAxisTitle); // var description = AD.append("g") // .attr("transform", "translate(" + @@ -2063,7 +2070,7 @@ function rvNumericalPlot(variableName, xData, xMinMax, yMinMax) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + rvPlotHeight/2)) .attr("text-anchor", "middle") - .text("residuals"); + .text(rvAxisTitle); var y = d3.scaleLinear() .range([margin.top + rvPlotHeight, margin.top - additionalHeight]) @@ -2239,7 +2246,7 @@ function rvCategoricalPlot(variableName, xData, xMinMax, yMinMax) { (margin.top + rvPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("residuals"); + .text(rvAxisTitle); } function fdNumericalPlot(variableName, dData, mData, nBin) { @@ -2265,7 +2272,7 @@ function fdNumericalPlot(variableName, dData, mData, nBin) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + fdPlotHeight/2)) .attr("text-anchor", "middle") - .text("count"); + .text(fdAxisTitle); var y = d3.scaleLinear() .range([margin.top + fdPlotHeight - 5, margin.top]); @@ -2485,7 +2492,7 @@ function fdCategoricalPlot(variableName, dData, xMinMax, mData) { (margin.top + fdPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("count"); + .text(fdAxisTitle); } function tvNumericalPlot(variableName, xData, xMinMax, yMinMax) { @@ -2536,7 +2543,7 @@ function tvNumericalPlot(variableName, xData, xMinMax, yMinMax) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + tvPlotHeight/2)) .attr("text-anchor", "middle") - .text("target"); + .text(tvAxisTitle); if (IS_TARGET_BINARY) { @@ -2711,7 +2718,7 @@ function tvCategoricalPlot(variableName, xData, xMinMax, yMinMax) { var xTitle; if (IS_TARGET_BINARY) { - xTitle = "average target" + xTitle = tvAxisTitle == "target" ? "average target" : tvAxisTitle; // find 5 nice ticks with max and min - do better than d3 var tickValues = getTickValues(x.domain()); @@ -2763,7 +2770,7 @@ function tvCategoricalPlot(variableName, xData, xMinMax, yMinMax) { .attr("x2", x(0)) .attr("y2", maximumY + y.bandwidth()); } else { - xTitle = "target"; + xTitle = tvAxisTitle; var xAxis = d3.axisBottom(x) .ticks(5) @@ -2932,7 +2939,7 @@ function atNumericalPlot(variableName, xData, xMinMax, yMinMax, yMean) { .attr("y", margin.left - margin.ytitle) .attr("x", -(margin.top + atPlotHeight/2)) .attr("text-anchor", "middle") - .text("average target"); + .text(atAxisTitle); AT.append("line") .attr("class", "interceptLine") @@ -3054,7 +3061,7 @@ function atCategoricalPlot(variableName, xData, xMinMax, yMinMax, yMean) { (margin.top + atPlotHeight + 45) + ")") .attr("class", "axisTitle") .attr("text-anchor", "middle") - .text("average target"); + .text(atAxisTitle); var bars = AT.selectAll() .data(xData) diff --git a/inst/d3js/modelStudio.js b/inst/d3js/modelStudio.js index ed284c1..f9bf3f4 100644 --- a/inst/d3js/modelStudio.js +++ b/inst/d3js/modelStudio.js @@ -46,17 +46,20 @@ var TIME = options.time, defaultColor = options.default_color, bdTitle = options.bd_title, bdSubtitle = options.bd_subtitle || subTitle, + bdAxisTitle = options.bd_axis_title, bdBarWidth = options.bd_bar_width || barWidth, bdPositiveColor = options.bd_positive_color || positiveColor, bdNegativeColor = options.bd_negative_color || negativeColor, bdDefaultColor = options.bd_default_color || defaultColor, svTitle = options.sv_title, svSubtitle = options.sv_subtitle || subTitle, + svAxisTitle = options.sv_axis_title, svBarWidth = options.sv_bar_width || barWidth, svPositiveColor = options.sv_positive_color || positiveColor, svNegativeColor = options.sv_negative_color || negativeColor, cpTitle = options.cp_title, cpSubtitle = options.cp_subtitle || subTitle, + cpAxisTitle = options.cp_axis_title, cpBarWidth = options.cp_bar_width || barWidth, cpLineSize = options.cp_line_size || lineSize, cpPointSize = options.cp_point_size || pointSize, @@ -65,34 +68,41 @@ var TIME = options.time, cpPointColor = options.cp_point_color || pointColor, fiTitle = options.fi_title, fiSubtitle = options.fi_subtitle || subTitle, + fiAxisTitle = options.fi_axis_title, fiBarWidth = options.fi_bar_width || barWidth, fiBarColor = options.fi_bar_color || barColor, pdTitle = options.pd_title, pdSubtitle = options.pd_subtitle || subTitle, + pdAxisTitle = options.pd_axis_title, pdBarWidth = options.pd_bar_width || barWidth, pdLineSize = options.pd_line_size || lineSize, pdBarColor = options.pd_bar_color || barColor, pdLineColor = options.pd_line_color || lineColor, adTitle = options.ad_title, adSubtitle = options.ad_subtitle || subTitle, + adAxisTitle = options.ad_axis_title, adBarWidth = options.ad_bar_width || barWidth, adLineSize = options.ad_line_size || lineSize, adBarColor = options.ad_bar_color || barColor, adLineColor = options.ad_line_color || lineColor, rvTitle = options.rv_title, rvSubtitle = options.rv_subtitle || subTitle, + rvAxisTitle = options.rv_axis_title, rvPointSize = options.rv_point_size || pointSize, rvPointColor = options.tv_point_color || pointColor, fdTitle = options.fd_title, fdSubtitle = options.fd_subtitle || subTitle, + fdAxisTitle = options.fd_axis_title, fdBarWidth = options.fd_bar_width || barWidth, fdBarColor = options.fd_bar_color || barColor, tvTitle = options.tv_title, tvSubtitle = options.tv_subtitle || subTitle, + tvAxisTitle = options.tv_axis_title, tvPointSize = options.tv_point_size || pointSize, tvPointColor = options.tv_point_color || pointColor, atTitle = options.at_title, atSubtitle = options.at_subtitle || subTitle, + atAxisTitle = options.at_axis_title, atBarWidth = options.at_bar_width || barWidth, atLineSize = options.at_line_size || lineSize, atPointSize = options.at_point_size || pointSize, diff --git a/man/ms_options.Rd b/man/ms_options.Rd index c443af4..ccc1252 100644 --- a/man/ms_options.Rd +++ b/man/ms_options.Rd @@ -42,18 +42,19 @@ ignored when \code{scale_plot = TRUE}.} \item{default_color}{\code{#371ea3} for Break Down bar and highlighted line.} } } -\subsection{Plot specific options:}{ +\subsection{Plot-specific options:}{ \code{**} is a two letter code unique to each plot, might be one of \code{[bd,sv,cp,fi,pd,ad,rv,fd,tv,at]}.\cr \describe{ -\item{**_title}{Plot specific title. Default varies.} -\item{**_subtitle}{Plot specific subtitle. Default is \code{subtitle}.} -\item{**_bar_width}{Plot specific width of bars. Default is \code{bar_width}, +\item{**_title}{Plot-specific title. Default varies.} +\item{**_subtitle}{Plot-specific subtitle. Default is \code{subtitle}.} +\item{**_axis_title}{Plot-specific axis title. Default varies.} +\item{**_bar_width}{Plot-specific width of bars. Default is \code{bar_width}, ignored when \code{scale_plot = TRUE}.} -\item{**_line_size}{\code{line_size} Plot specific width of lines. Default is \code{line_size}.} -\item{**_point_size}{Plot specific point radius. Default is \code{point_size}.} -\item{**_*_color}{Plot specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} +\item{**_line_size}{Plot-specific width of lines. Default is \code{line_size}.} +\item{**_point_size}{Plot-specific point radius. Default is \code{point_size}.} +\item{**_*_color}{Plot-specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} } } } diff --git a/man/ms_update_options.Rd b/man/ms_update_options.Rd index 473373c..372823b 100644 --- a/man/ms_update_options.Rd +++ b/man/ms_update_options.Rd @@ -43,18 +43,19 @@ ignored when \code{scale_plot = TRUE}.} \item{default_color}{\code{#371ea3} for Break Down bar and highlighted line.} } } -\subsection{Plot specific options:}{ +\subsection{Plot-specific options:}{ \code{**} is a two letter code unique to each plot, might be one of \code{[bd,sv,cp,fi,pd,ad,rv,fd,tv,at]}.\cr \describe{ -\item{**_title}{Plot specific title. Default varies.} -\item{**_subtitle}{Plot specific subtitle. Default is \code{subtitle}.} -\item{**_bar_width}{Plot specific width of bars. Default is \code{bar_width}, +\item{**_title}{Plot-specific title. Default varies.} +\item{**_subtitle}{Plot-specific subtitle. Default is \code{subtitle}.} +\item{**_axis_title}{Plot-specific axis title. Default varies.} +\item{**_bar_width}{Plot-specific width of bars. Default is \code{bar_width}, ignored when \code{scale_plot = TRUE}.} -\item{**_line_size}{\code{line_size} Plot specific width of lines. Default is \code{line_size}.} -\item{**_point_size}{Plot specific point radius. Default is \code{point_size}.} -\item{**_*_color}{Plot specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} +\item{**_line_size}{Plot-specific width of lines. Default is \code{line_size}.} +\item{**_point_size}{Plot-specific point radius. Default is \code{point_size}.} +\item{**_*_color}{Plot-specific \code{[bar,line,point]} color. Default is \code{[bar,line,point]_color}.} } } }