diff --git a/R/plot_comparisons.R b/R/plot_comparisons.R index bcb48ac6d..2d7e14e87 100644 --- a/R/plot_comparisons.R +++ b/R/plot_comparisons.R @@ -66,13 +66,7 @@ plot_comparisons <- function(model, # order of the first few paragraphs is important scall <- rlang::enquo(newdata) - if (!is.null(condition) && !is.null(newdata)) { - insight::format_error("The `condition` and `newdata` arguments cannot be used simultaneously.") - } newdata <- sanitize_newdata_call(scall, newdata, model) - if (!is.null(newdata) && is.null(by)) { - insight::format_error("The `newdata` argument requires a `by` argument.") - } if (!is.null(wts) && is.null(by)) { insight::format_error("The `wts` argument requires a `by` argument.") } @@ -89,9 +83,18 @@ plot_comparisons <- function(model, checkmate::check_list(variables, names = "unique"), .var.name = "variables") + modeldata <- get_modeldata( + model, + additional_variables = c(names(condition), by), + wts = wts) + + # mlr3 and tidymodels + if (is.null(modeldata) || nrow(modeldata) == 0) { + modeldata <- newdata + } + # conditional if (!is.null(condition)) { - modeldata <- get_modeldata(model, additional_variables = names(condition), wts = wts) condition <- sanitize_condition(model, condition, variables, modeldata = modeldata) v_x <- condition$condition1 v_color <- condition$condition2 @@ -114,7 +117,6 @@ plot_comparisons <- function(model, # marginal if (!is.null(by)) { - modeldata <- get_modeldata(model, additional_variables = by, wts = wts) newdata <- sanitize_newdata( model = model, newdata = newdata, diff --git a/R/plot_predictions.R b/R/plot_predictions.R index 785f0e43e..47ca0d037 100644 --- a/R/plot_predictions.R +++ b/R/plot_predictions.R @@ -75,13 +75,7 @@ plot_predictions <- function(model, # order of the first few paragraphs is important scall <- rlang::enquo(newdata) - if (!is.null(condition) && !is.null(newdata)) { - insight::format_error("The `condition` and `newdata` arguments cannot be used simultaneously.") - } newdata <- sanitize_newdata_call(scall, newdata, model) - if (!is.null(newdata) && is.null(by)) { - insight::format_error("The `newdata` argument requires a `by` argument.") - } if (!is.null(wts) && is.null(by)) { insight::format_error("The `wts` argument requires a `by` argument.") } @@ -94,9 +88,18 @@ plot_predictions <- function(model, insight::format_error(msg) } + modeldata <- get_modeldata( + model, + additional_variables = c(names(condition), by), + wts = wts) + + # mlr3 and tidymodels + if (is.null(modeldata) || nrow(modeldata) == 0) { + modeldata <- newdata + } + # conditional if (!is.null(condition)) { - modeldata <- get_modeldata(model, additional_variables = names(condition), wts = wts) condition <- sanitize_condition(model, condition, variables = NULL, modeldata = modeldata) v_x <- condition$condition1 v_color <- condition$condition2 @@ -116,13 +119,19 @@ plot_predictions <- function(model, # marginal if (!isFALSE(by) && !is.null(by)) { # switched from NULL above condition <- NULL - modeldata <- get_modeldata(model, additional_variables = by, wts = wts) + newdata <- sanitize_newdata( model = model, newdata = newdata, modeldata = modeldata, by = by, wts = wts) + + # tidymodels & mlr3 + if (is.null(modeldata)) { + modeldata <- newdata + } + datplot <- predictions( model, by = by, diff --git a/R/sanitize_condition.R b/R/sanitize_condition.R index b264bbf3c..27ff93bf3 100644 --- a/R/sanitize_condition.R +++ b/R/sanitize_condition.R @@ -149,6 +149,11 @@ sanitize_condition <- function(model, condition, variables = NULL, modeldata = N } } + # mlr3 and tidymodels are not supported by `insight::find_variables()`, so we need to create a grid based on all the variables supplied in `newdata` + if (inherits(at_list$model, "Learner") || inherits(at_list$model, "model_fit")) { + at_list$model <- NULL + } + # create data nd <- do.call("datagrid", at_list) diff --git a/book/articles/machine_learning.qmd b/book/articles/machine_learning.qmd index 5702ae7bd..166f57160 100644 --- a/book/articles/machine_learning.qmd +++ b/book/articles/machine_learning.qmd @@ -3,7 +3,7 @@ title: "Machine Learning" --- -`marginaleffects` offers several "model-agnostic" functions to interpret statistical and machine learning models. This vignette highlights how the package can be used to extract meaningful insights from the results of models trained using the `mlr3` and `tidymodels` framework. +`marginaleffects` offers several "model-agnostic" functions to interpret statistical and machine learning models. This vignette highlights how the package can be used to extract meaningful insights from the results of models trained using the `mlr3` and `tidymodels` frameworks. The features in this vignette require version 0.16.0 or `marginaleffects`, or the development version which can be installed from Github: @@ -16,14 +16,15 @@ Make sure to restart `R` after installation. Then, load a few libraries: ```{r} #| message: false #| warning: false -options(width = 10000) library("marginaleffects") library("fmeffects") library("ggplot2") library("mlr3verse") library("tidymodels") |> suppressPackageStartupMessages() +options(width = 10000) ``` ```{r} +#| include: false pkgload::load_all() ``` @@ -43,8 +44,11 @@ forest <- lrn("regr.ranger")$train(task) As described in other vignettes, we can use the `avg_comparisons()` function to compute the average change in predicted outcome that is associated with a change in each feature: ```{r} +avg_comparisons(forest, newdata = bikes) +``` +```{r} +#| include: false cmp <- avg_comparisons(forest, newdata = bikes) -cmp ``` These results are easy to interpret: An increase of 1 degree Celsius in the temperature is associated with an increase of `r sprintf("%.3f", cmp$estimate[cmp$term == "temp"])` bikes rented per hour. @@ -87,6 +91,11 @@ avg_comparisons( newdata = bikes) ``` + +```{r} +plot_predictions(forest, condition = "temp") +``` + # tidymodels `marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain estimates of uncertainty estimates: diff --git a/inst/tinytest/test-pkg-ml3rverse.R b/inst/tinytest/test-pkg-mlr3verse.R similarity index 89% rename from inst/tinytest/test-pkg-ml3rverse.R rename to inst/tinytest/test-pkg-mlr3verse.R index 2a766dc21..577061645 100644 --- a/inst/tinytest/test-pkg-ml3rverse.R +++ b/inst/tinytest/test-pkg-mlr3verse.R @@ -9,6 +9,11 @@ data("bikes", package = "fmeffects") task <- as_task_regr(x = bikes, id = "bikes", target = "count") forest <- lrn("regr.ranger")$train(task) + +# Plot predictions +p <- plot_predictions(forest, condition = "temp", type = "response", newdata = bikes) +expect_inherits(p, "gg") + # Centered difference cmp <- avg_comparisons(forest, newdata = bikes) expect_inherits(cmp, "comparisons")