From 1dc2071959c5488a66ee357378b4303820094a64 Mon Sep 17 00:00:00 2001 From: Vincent Arel-Bundock Date: Thu, 5 Oct 2023 09:16:35 -0400 Subject: [PATCH] works? --- R/sanitize_type.R | 11 +++++++ book/articles/machine_learning.qmd | 47 ++++++++++++++++++++++-------- inst/tinytest/test-pkg-mlr3verse.R | 1 - 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/R/sanitize_type.R b/R/sanitize_type.R index c272ce33a..a7855c808 100644 --- a/R/sanitize_type.R +++ b/R/sanitize_type.R @@ -13,6 +13,17 @@ sanitize_type <- function(model, type, calling_function = "raw") { return(type) } + # mlr3 + if (inherits(model, "Learner")) { + valid <- setdiff(model$predict_types, "se") + checkmate::assert_choice(type, choices = valid, null.ok = TRUE) + return(type) + } + + if (is.null(type)) { + return(type) + } + checkmate::assert_character(type, len = 1, null.ok = TRUE) cl <- class(model)[1] if (!cl %in% type_dictionary$class) { diff --git a/book/articles/machine_learning.qmd b/book/articles/machine_learning.qmd index 166f57160..532ba2a80 100644 --- a/book/articles/machine_learning.qmd +++ b/book/articles/machine_learning.qmd @@ -91,11 +91,6 @@ avg_comparisons( newdata = bikes) ``` - -```{r} -plot_predictions(forest, condition = "temp") -``` - # tidymodels `marginaleffects` also supports the `tidymodels` machine learning framework. When the underlying engine used by `tidymodels` to train the model is itself supported as a standalone package by `marginaleffects`, we can obtain estimates of uncertainty estimates: @@ -109,12 +104,6 @@ mod <- linear_reg(mode = "regression") |> avg_comparisons(mod, newdata = bikes, type = "response") ``` -We can also plot the results as usual: - -```{r, warning = FALSE} -plot_predictions(mod, condition = "temp", points = .2) -``` - When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can also obtain correct results, but no uncertainy estimates. 
Here is a random forest model: ```{r} @@ -122,6 +111,40 @@ forest_tidy <- rand_forest(mode = "regression") |> set_engine("ranger") |> fit(count ~ ., data = bikes) avg_comparisons(forest_tidy, newdata = bikes, type = "numeric") +``` + +# Plot -plot_predictions(forest_tidy, newdata = bikes, by = "temp", type = "numeric") +We can plot the results using the standard `marginaleffects` helpers. For example, to plot predictions, we can do: + +```{r} +plot_predictions(forest, condition = "temp", newdata = bikes) ``` + +As documented in `?plot_predictions`, using `condition="temp"` is equivalent to creating an equally-spaced grid of `temp` values, and holding all other predictors at their means or modes. In other words, it is equivalent to: + +```{r} +#| eval: false +d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes) +p <- predict(forest, newdata = d) +plot(d$temp, p, type = "l") +``` + +Alternatively, we could plot "marginal" predictions, where we replicate the full dataset once for every value of `temp`, and then average the predicted values over each value of the x-axis: + + +```{r} +d <- datagridcf(newdata = bikes, temp = unique) +plot_predictions(forest, by = "temp", newdata = d) +``` + +Of course, we can customize the plot using all the standard `ggplot2` functions: + +```{r} +plot_predictions(forest, by = "temp", newdata = d) + + geom_point(data = bikes, aes(x = temp, y = count), alpha = 0.1) + + geom_smooth(data = bikes, aes(x = temp, y = count), se = FALSE, color = "orange") + + labs(x = "Temperature (Celsius)", y = "Predicted number of bikes rented per hour", + title = "Black: random forest predictions. 
Orange: LOESS smoother.") + + theme_bw() +``` \ No newline at end of file diff --git a/inst/tinytest/test-pkg-mlr3verse.R b/inst/tinytest/test-pkg-mlr3verse.R index 577061645..a4da4040b 100644 --- a/inst/tinytest/test-pkg-mlr3verse.R +++ b/inst/tinytest/test-pkg-mlr3verse.R @@ -9,7 +9,6 @@ data("bikes", package = "fmeffects") task <- as_task_regr(x = bikes, id = "bikes", target = "count") forest <- lrn("regr.ranger")$train(task) - # Plot predictions p <- plot_predictions(forest, condition = "temp", type = "response", newdata = bikes) expect_inherits(p, "gg")