Skip to content

Commit

Permalink
works?
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Oct 5, 2023
1 parent 6c6bdcc commit 1dc2071
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
11 changes: 11 additions & 0 deletions R/sanitize_type.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ sanitize_type <- function(model, type, calling_function = "raw") {
return(type)
}

# mlr3
if (inherits(model, "Learner")) {
valid <- setdiff(model$predict_types, "se")
checkmate::assert_choice(type, choices = valid, null.ok = TRUE)
return(type)
}

if (is.null(type)) {
return(type)
}

checkmate::assert_character(type, len = 1, null.ok = TRUE)
cl <- class(model)[1]
if (!cl %in% type_dictionary$class) {
Expand Down
47 changes: 35 additions & 12 deletions book/articles/machine_learning.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,6 @@ avg_comparisons(
newdata = bikes)
```


```{r}
plot_predictions(forest, condition = "temp")
```

# tidymodels

`marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain estimates of uncertainty estimates:
Expand All @@ -109,19 +104,47 @@ mod <- linear_reg(mode = "regression") |>
avg_comparisons(mod, newdata = bikes, type = "response")
```

We can also plot the results as usual:

```{r, warning = FALSE}
plot_predictions(mod, condition = "temp", points = .2)
```

When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can also obtain correct results, but no uncertainy estimates. Here is a random forest model:

```{r}
forest_tidy <- rand_forest(mode = "regression") |>
set_engine("ranger") |>
fit(count ~ ., data = bikes)
avg_comparisons(forest_tidy, newdata = bikes, type = "numeric")
```

# Plot

plot_predictions(forest_tidy, newdata = bikes, by = "temp", type = "numeric")
We can plot the results using the standard `marginaleffects` helpers. For example, to plot predictions, we can do:

```{r}
plot_predictions(forest, condition = "temp", newdata = bikes)
```

As documented in `?plot_predictions`, using `condition="temp"` is equivalent to creating an equally-spaced grid of `temp` values, and holding all other predictors at their means or modes. In other words, it is equivalent to:

```{r}
#| eval: false
d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
p <- predict(forest, newdata = d)
plot(d$temp, p, type = "l")
```

Alternatively, we could plot "marginal" predictions, where replicate the full dataset once for every value of `temp`, and then average the predicted values over each value of the x-axis:


```{r}
d <- datagridcf(newdata = bikes, temp = unique)
plot_predictions(forest, by = "temp", newdata = d)
```

Of course, we can customize the plot using all the standard `ggplot2` functions:

```{r}
plot_predictions(forest, by = "temp", newdata = d) +
geom_point(data = bikes, aes(x = temp, y = count), alpha = 0.1) +
geom_smooth(data = bikes, aes(x = temp, y = count), se = FALSE, color = "orange") +
labs(x = "Temperature (Celcius)", y = "Predicted number of bikes rented per hour",
title = "Black: random forest predictions. Green: LOESS smoother.") +
theme_bw()
```
1 change: 0 additions & 1 deletion inst/tinytest/test-pkg-mlr3verse.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ data("bikes", package = "fmeffects")
task <- as_task_regr(x = bikes, id = "bikes", target = "count")
forest <- lrn("regr.ranger")$train(task)


# Plot predictions
p <- plot_predictions(forest, condition = "temp", type = "response", newdata = bikes)
expect_inherits(p, "gg")
Expand Down

0 comments on commit 1dc2071

Please sign in to comment.