Commit: vignette
vincentarelbundock committed Oct 6, 2023
1 parent 48dd20f commit 0b0e8d9
Showing 2 changed files with 110 additions and 36 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: marginaleffects
Title: Predictions, Comparisons, Slopes, Marginal Means, and Hypothesis Tests
Version: 0.15.1.9011
Version: 0.15.1.9012
Authors@R:
c(person(given = "Vincent",
family = "Arel-Bundock",
@@ -160,6 +160,7 @@ Suggests:
withr,
workflows,
yaml,
xgboost,
testthat (>= 3.0.0)
Collate:
'RcppExports.R'
143 changes: 108 additions & 35 deletions book/articles/machine_learning.qmd
@@ -2,7 +2,6 @@
title: "Machine Learning"
---


`marginaleffects` offers several "model-agnostic" functions to interpret statistical and machine learning models. This vignette highlights how the package can be used to extract meaningful insights from models trained using the `mlr3` and `tidymodels` frameworks.

The features in this vignette require version 0.16.0 of `marginaleffects`, or the development version, which can be installed from GitHub:
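
A typical call to install the development version might look like this sketch (assuming the `remotes` package is available and that the source lives in the `vincentarelbundock/marginaleffects` GitHub repository):

```{r}
#| eval: false
# Sketch only: install the development version from GitHub
# (assumes the `remotes` package is installed; repository name is an assumption)
remotes::install_github("vincentarelbundock/marginaleffects")
```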
@@ -20,10 +19,102 @@ library("marginaleffects")
library("fmeffects")
library("ggplot2")
library("mlr3verse")
library("modelsummary")
library("tidymodels") |> suppressPackageStartupMessages()
options(width = 10000)
```


## `tidymodels`

`marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain both estimates and their standard errors:

```{r, message = FALSE}
#| warning: false
library(tidymodels)
penguins <- modeldata::penguins |>
    na.omit() |>
    select(sex, island, species, bill_length_mm)
mod <- linear_reg(mode = "regression") |>
    set_engine("lm") |>
    fit(bill_length_mm ~ ., data = penguins)
avg_comparisons(mod, type = "numeric", newdata = penguins)
avg_predictions(mod, type = "numeric", newdata = penguins, by = "island")
```

When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can still obtain correct results, but no uncertainty estimates. Here is a random forest model:

```{r}
mod <- rand_forest(mode = "regression") |>
set_engine("ranger") |>
fit(bill_length_mm ~ ., data = penguins)
avg_comparisons(mod, newdata = penguins, type = "numeric")
```

### Workflows

`tidymodels` "workflows" are a convenient way to train a model while applying a series of pre-processing steps to the data. `marginaleffects` supports workflows out of the box. First, let's consider a simple regression task:

```{r}
penguins <- modeldata::penguins |>
    na.omit() |>
    select(sex, island, species, bill_length_mm)
mod <- penguins |>
    recipe(bill_length_mm ~ island + species + sex, data = _) |>
    step_dummy(all_nominal_predictors()) |>
    workflow(spec = linear_reg(mode = "regression", engine = "glm")) |>
    fit(penguins)
avg_comparisons(mod, newdata = penguins, type = "numeric")
```

Now, we run a classification task instead, and plot the predicted probabilities:

```{r}
mod <- penguins |>
    recipe(sex ~ island + species + bill_length_mm, data = _) |>
    step_dummy(all_nominal_predictors()) |>
    workflow(spec = logistic_reg(mode = "classification", engine = "glm")) |>
    fit(penguins)
plot_predictions(
    mod,
    condition = c("bill_length_mm", "group"),
    newdata = penguins,
    type = "prob")
```

Finally, let's consider a more complex task, where we train several models and summarize them in a table using `modelsummary`:

```{r, error=T}
library(modelsummary)
recipe <- penguins |>
    recipe(sex ~ ., data = _) |>
    step_ns(bill_length_mm, deg_free = 4) |>
    step_dummy(all_nominal_predictors())
models <- list(
    logit = logistic_reg(mode = "classification", engine = "glm"),
    forest = rand_forest(mode = "classification", engine = "ranger"),
    xgb = boost_tree(mode = "classification", engine = "xgboost")
)
lapply(models, \(x) {
    recipe |>
        workflow(spec = x) |>
        fit(penguins) |>
        avg_comparisons(newdata = penguins, type = "prob") }) |>
    modelsummary(shape = term + contrast + group ~ model)
```


## `mlr3`

`mlr3` is a machine learning framework for `R`. It makes it possible for users to train a wide range of models, including linear models, random forests, gradient boosting machines, and neural networks.
@@ -37,7 +128,7 @@ task <- as_task_regr(x = bikes, id = "bikes", target = "count")
forest <- lrn("regr.ranger")$train(task)
```

As described in other vignettes, we can use the `avg_comparisons()` function to compute the average change in predicted outcome that is associated with a change in each feature:
As described in other vignettes, we can use the [`avg_comparisons()`](reference/comparisons.html) function to compute the average change in predicted outcome that is associated with a change in each feature:

```{r}
avg_comparisons(forest, newdata = bikes)
@@ -57,7 +148,10 @@ hi <- transform(bikes, temp = temp + 0.5)
mean(predict(forest, newdata = hi) - predict(forest, newdata = lo))
```

As the code above makes clear, the `avg_comparisons()` function computes the effect of a "centered" change on the outcome. If we want to compute a "Forward Marginal Effect" instead, we can call:

### `fmeffects`: Forward or centered effects

As the code above makes clear, the [`avg_comparisons()`](reference/comparisons.html) function computes the effect of a "centered" change on the outcome. If we want to compute a "Forward Marginal Effect" instead, we can call:

```{r}
avg_comparisons(
@@ -87,46 +181,26 @@ avg_comparisons(
    newdata = bikes)
```
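
For intuition, a forward change can also be computed by hand, in the spirit of the centered computation shown earlier. This is only a sketch under the assumption of a +1 step in `temp`; the step size and options used by the call above may differ:

```{r}
#| eval: false
# Sketch: a "forward" difference of +1 in temp, computed by hand.
# Predict at the shifted data and at the observed data, then average.
hi <- transform(bikes, temp = temp + 1)
mean(predict(forest, newdata = hi) - predict(forest, newdata = bikes))
```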

# `tidymodels`

`marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain both estimates and their standard errors:

```{r}
#| include: false
library(tidymodels)
```

```{r, message = FALSE}
#| warning: false
library(tidymodels)
mod <- linear_reg(mode = "regression") |>
set_engine("lm") |>
fit(count ~ ., data = bikes)
avg_comparisons(mod, type = "numeric", newdata = bikes)
```

When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can still obtain correct results, but no uncertainty estimates. Here is a random forest model:

```{r}
forest_tidy <- rand_forest(mode = "regression") |>
set_engine("ranger") |>
fit(count ~ ., data = bikes)

avg_comparisons(forest_tidy, newdata = bikes, type = "numeric")
```

# Plot
## Plots

We can plot the results using the standard `marginaleffects` helpers. For example, to plot predictions, we can do:

```{r}
library(mlr3verse)
data("bikes", package = "fmeffects")
task <- as_task_regr(x = bikes, id = "bikes", target = "count")
forest <- lrn("regr.ranger")$train(task)
plot_predictions(forest, condition = "temp", newdata = bikes)
```

As documented in `?plot_predictions`, using `condition="temp"` is equivalent to creating an equally-spaced grid of `temp` values, and holding all other predictors at their means or modes. In other words, it is equivalent to:

```{r}
#| include: false
d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
```
```{r}
#| eval: false
d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
@@ -136,7 +210,6 @@ plot(d$temp, p, type = "l")
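
Spelled out in full, that equivalence might look like the following sketch (assumption: `p` holds the random forest predictions on the grid `d`):

```{r}
#| eval: false
# Sketch of the equivalence described above (assumes `p` is predictions on `d`)
d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
p <- predict(forest, newdata = d)
plot(d$temp, p, type = "l")
```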

Alternatively, we could plot "marginal" predictions, where we replicate the full dataset once for every value of `temp`, and then average the predicted values over each value of the x-axis:


```{r}
plot_predictions(forest, by = "temp", newdata = bikes)
```
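
For intuition, the "replicate and average" logic can be written out by hand. The sketch below assumes we loop over the observed values of `temp`; the grid actually used by `plot_predictions()` may differ:

```{r}
#| eval: false
# Sketch of the "replicate and average" logic described above
temps <- sort(unique(bikes$temp))
marginal <- sapply(temps, function(t) {
    d <- transform(bikes, temp = t)      # replicate the full dataset at this temp
    mean(predict(forest, newdata = d))   # average the predictions over the replicates
})
plot(temps, marginal, type = "l")
```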
@@ -150,4 +223,4 @@ plot_predictions(forest, by = "temp", newdata = d) +
labs(x = "Temperature (Celcius)", y = "Predicted number of bikes rented per hour",
title = "Black: random forest predictions. Green: LOESS smoother.") +
theme_bw()
```
