Skip to content

Commit

Permalink
plots seem to work
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Oct 5, 2023
1 parent 5409f0f commit 6c6bdcc
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 19 deletions.
18 changes: 10 additions & 8 deletions R/plot_comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,7 @@ plot_comparisons <- function(model,

# order of the first few paragraphs is important
scall <- rlang::enquo(newdata)
if (!is.null(condition) && !is.null(newdata)) {
insight::format_error("The `condition` and `newdata` arguments cannot be used simultaneously.")
}
newdata <- sanitize_newdata_call(scall, newdata, model)
if (!is.null(newdata) && is.null(by)) {
insight::format_error("The `newdata` argument requires a `by` argument.")
}
if (!is.null(wts) && is.null(by)) {
insight::format_error("The `wts` argument requires a `by` argument.")
}
Expand All @@ -89,9 +83,18 @@ plot_comparisons <- function(model,
checkmate::check_list(variables, names = "unique"),
.var.name = "variables")

modeldata <- get_modeldata(
model,
additional_variables = c(names(condition), by),
wts = wts)

# mlr3 and tidymodels
if (is.null(modeldata) || nrow(modeldata) == 0) {
modeldata <- newdata
}

# conditional
if (!is.null(condition)) {
modeldata <- get_modeldata(model, additional_variables = names(condition), wts = wts)
condition <- sanitize_condition(model, condition, variables, modeldata = modeldata)
v_x <- condition$condition1
v_color <- condition$condition2
Expand All @@ -114,7 +117,6 @@ plot_comparisons <- function(model,

# marginal
if (!is.null(by)) {
modeldata <- get_modeldata(model, additional_variables = by, wts = wts)
newdata <- sanitize_newdata(
model = model,
newdata = newdata,
Expand Down
25 changes: 17 additions & 8 deletions R/plot_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,7 @@ plot_predictions <- function(model,

# order of the first few paragraphs is important
scall <- rlang::enquo(newdata)
if (!is.null(condition) && !is.null(newdata)) {
insight::format_error("The `condition` and `newdata` arguments cannot be used simultaneously.")
}
newdata <- sanitize_newdata_call(scall, newdata, model)
if (!is.null(newdata) && is.null(by)) {
insight::format_error("The `newdata` argument requires a `by` argument.")
}
if (!is.null(wts) && is.null(by)) {
insight::format_error("The `wts` argument requires a `by` argument.")
}
Expand All @@ -94,9 +88,18 @@ plot_predictions <- function(model,
insight::format_error(msg)
}

modeldata <- get_modeldata(
model,
additional_variables = c(names(condition), by),
wts = wts)

# mlr3 and tidymodels
if (is.null(modeldata) || nrow(modeldata) == 0) {
modeldata <- newdata
}

# conditional
if (!is.null(condition)) {
modeldata <- get_modeldata(model, additional_variables = names(condition), wts = wts)
condition <- sanitize_condition(model, condition, variables = NULL, modeldata = modeldata)
v_x <- condition$condition1
v_color <- condition$condition2
Expand All @@ -116,13 +119,19 @@ plot_predictions <- function(model,
# marginal
if (!isFALSE(by) && !is.null(by)) { # switched from NULL above
condition <- NULL
modeldata <- get_modeldata(model, additional_variables = by, wts = wts)

newdata <- sanitize_newdata(
model = model,
newdata = newdata,
modeldata = modeldata,
by = by,
wts = wts)

# tidymodels & mlr3
if (is.null(modeldata)) {
modeldata <- newdata
}

datplot <- predictions(
model,
by = by,
Expand Down
5 changes: 5 additions & 0 deletions R/sanitize_condition.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ sanitize_condition <- function(model, condition, variables = NULL, modeldata = N
}
}

# mlr3 and tidymodels are not supported by `insight::find_variables()`, so we need to create a grid based on all the variables supplied in `newdata`
if (inherits(at_list$model, "Learner") || inherits(at_list$model, "model_fit")) {
at_list$model <- NULL
}

# create data
nd <- do.call("datagrid", at_list)

Expand Down
15 changes: 12 additions & 3 deletions book/articles/machine_learning.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ title: "Machine Learning"
---


`marginaleffects` offers several "model-agnostic" functions to interpret statistical and machine learning models. This vignette highlights how the package can be used to extract meaningful insights from the results of models trained using the `mlr3` and `tidymodels` framework.
`marginaleffects` offers several "model-agnostic" functions to interpret statistical and machine learning models. This vignette highlights how the package can be used to extract meaningful insights from the results of models trained using the `mlr3` and `tidymodels` frameworks.

The features in this vignette require version 0.16.0 of `marginaleffects`, or the development version, which can be installed from GitHub:

Expand All @@ -16,14 +16,15 @@ Make sure to restart `R` after installation. Then, load a few libraries:
```{r}
#| message: false
#| warning: false
options(width = 10000)
library("marginaleffects")
library("fmeffects")
library("ggplot2")
library("mlr3verse")
library("tidymodels") |> suppressPackageStartupMessages()
options(width = 10000)
```
```{r}
#| include: false
pkgload::load_all()
```

Expand All @@ -43,8 +44,11 @@ forest <- lrn("regr.ranger")$train(task)
As described in other vignettes, we can use the `avg_comparisons()` function to compute the average change in predicted outcome that is associated with a change in each feature:

```{r}
avg_comparisons(forest, newdata = bikes)
```
```{r}
#| include: false
cmp <- avg_comparisons(forest, newdata = bikes)
cmp
```

These results are easy to interpret: An increase of 1 degree Celsius in the temperature is associated with an increase of `r sprintf("%.3f", cmp$estimate[cmp$term == "temp"])` bikes rented per hour.
Expand Down Expand Up @@ -87,6 +91,11 @@ avg_comparisons(
newdata = bikes)
```


```{r}
plot_predictions(forest, condition = "temp")
```

# tidymodels

`marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain estimates of uncertainty:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ data("bikes", package = "fmeffects")
task <- as_task_regr(x = bikes, id = "bikes", target = "count")
forest <- lrn("regr.ranger")$train(task)


# Plot predictions
p <- plot_predictions(forest, condition = "temp", type = "response", newdata = bikes)
expect_inherits(p, "gg")

# Centered difference
cmp <- avg_comparisons(forest, newdata = bikes)
expect_inherits(cmp, "comparisons")
Expand Down

0 comments on commit 6c6bdcc

Please sign in to comment.