From 2bb532ef32883e7fa002d403e3cc29dd463aee55 Mon Sep 17 00:00:00 2001 From: Vincent Arel-Bundock Date: Wed, 4 Oct 2023 11:13:41 -0400 Subject: [PATCH] plot doesn't work --- R/predictions.R | 2 +- R/sanitize_condition.R | 1 + book/articles/machine_learning.qmd | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/R/predictions.R b/R/predictions.R index 14e0778d3..9204f261e 100644 --- a/R/predictions.R +++ b/R/predictions.R @@ -274,7 +274,7 @@ predictions <- function(model, # if type is NULL, we backtransform if relevant type_string <- sanitize_type(model = model, type = type, calling_function = "predictions") - if (type_string == "invlink(link)") { + if (identical(type_string, "invlink(link)")) { if (is.null(hypothesis)) { type_call <- "link" } else { diff --git a/R/sanitize_condition.R b/R/sanitize_condition.R index b264bbf3c..63ba6dd65 100644 --- a/R/sanitize_condition.R +++ b/R/sanitize_condition.R @@ -57,6 +57,7 @@ sanitize_condition <- function(model, condition, variables = NULL, modeldata = N respname <- insight::find_response(model) flag <- checkmate::check_true(all(names(condition) %in% c(colnames(dat), "group"))) + browser() if (!isTRUE(flag)) { msg <- sprintf("Entries in the `condition` argument must be element of: %s", paste(colnames(dat), collapse = ", ")) diff --git a/book/articles/machine_learning.qmd b/book/articles/machine_learning.qmd index fdfcb4222..96e93d9ec 100644 --- a/book/articles/machine_learning.qmd +++ b/book/articles/machine_learning.qmd @@ -88,6 +88,12 @@ mod <- linear_reg(mode = "regression") |> avg_comparisons(mod, newdata = bikes, type = "response") ``` +We can also plot the results as usual: + +```{r, warning = FALSE} +plot_predictions(mod, condition = "temp", points = .2) +``` + When the engine is not supported by `marginaleffects`, we can also obtain correct results, but no uncertainy estimates. Here is a random forest model: ```{r} @@ -95,4 +101,6 @@ forest_tidy <- rand_forest(mode = "regression") |> set_engine("ranger") |> fit(count ~ ., data = bikes) avg_comparisons(forest_tidy, newdata = bikes, type = "numeric") + +plot_predictions(forest_tidy, newdata = bikes, by = "temp", type = "numeric") ```