diff --git a/book/articles/machine_learning.qmd b/book/articles/machine_learning.qmd
index 0e2b83474..7fa76b926 100644
--- a/book/articles/machine_learning.qmd
+++ b/book/articles/machine_learning.qmd
@@ -49,63 +49,7 @@ avg_predictions(mod, type = "numeric", newdata = penguins, by = "island")
 When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can still obtain correct results, but no uncertainty estimates. Here is a random forest model:
 
 ```{r}
-mod <- rand_forest(mode = "regression") |>
-    set_engine("ranger") |>
-    fit(bill_length_mm ~ ., data = penguins)
-
-avg_comparisons(mod, newdata = penguins, type = "numeric")
-```
-
-### Workflows
-
-`tidymodels` "workflows" are a convenient way to train a model while applying a series of pre-processing steps to the data. `marginaleffects` supports workflows out of the box. First, let's consider a simple regression task:
-
-```{r}
-penguins <- modeldata::penguins |>
-    na.omit() |>
-    select(sex, island, species, bill_length_mm)
-
-mod <- penguins |>
-    recipe(bill_length_mm ~ island + species + sex, data = _) |>
-    step_dummy(all_nominal_predictors()) |>
-    workflow(spec = linear_reg(mode = "regression", engine = "glm")) |>
-    fit(penguins)
-
-avg_comparisons(mod, newdata = penguins, type = "numeric")
-```
-
-Now, we run a classification task instead, and plot the predicted probabilities:
-
-```{r}
-mod <- penguins |>
-    recipe(sex ~ island + species + bill_length_mm, data = _) |>
-    step_dummy(all_nominal_predictors()) |>
-    workflow(spec = logistic_reg(mode = "classification", engine = "glm")) |>
-    fit(penguins)
-
-plot_predictions(
-    mod,
-    condition = c("bill_length_mm", "group"),
-    newdata = penguins,
-    type = "prob")
-```
-
-Finally, let's consider a more complex task, where we train several models and summarize them in a table using `modelsummary`:
-
-```{r, error=T}
-library(modelsummary)
-
-# pre-processing
-pre <- penguins |>
-    recipe(sex ~ ., data = _) |>
-    step_ns(bill_length_mm, deg_free = 4) |>
-    step_dummy(all_nominal_predictors())
-
-# modelling strategies
-models <- list(
-    "Logit" = logistic_reg(mode = "classification", engine = "glm"),
-    "Random Forest" = rand_forest(mode = "classification", engine = "ranger"),
-    "XGBooost" = boost_tree(mode = "classification", engine = "xgboost")
-)
-
-# fit to data
+mod <- rand_forest(mode = "regression") |>
+    set_engine("ranger") |>
+    fit(bill_length_mm ~ ., data = penguins)
+avg_comparisons(mod, newdata = penguins, type = "numeric")
@@ -232,7 +176,7 @@ Of course, we can customize the plot using all the standard `ggplot2` functions:
 plot_predictions(forest, by = "temp", newdata = d) +
     geom_point(data = bikes, aes(x = temp, y = count), alpha = 0.1) +
     geom_smooth(data = bikes, aes(x = temp, y = count), se = FALSE, color = "orange") +
-    labs(x = "Temperature (Celcius)", y = "Predicted number of bikes rented per hour",
-        title = "Black: random forest predictions. Green: LOESS smoother.") +
+    labs(x = "Temperature (Celsius)", y = "Predicted number of bikes rented per hour",
+        title = "Black: random forest predictions. Orange: LOESS smoother.") +
     theme_bw()
 ```
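
For reviewers who want to run the example this patch keeps, here is a self-contained sketch; it assumes the `tidymodels`, `marginaleffects`, and `modeldata` packages are installed, and that the penguins data is cleaned with `na.omit()` as elsewhere in the article:

```r
library(tidymodels)
library(marginaleffects)

# drop rows with missing values, as in the earlier sections of the article
penguins <- na.omit(modeldata::penguins)

# ranger is not supported by marginaleffects as a standalone model, so
# avg_comparisons() returns point estimates without uncertainty estimates
mod <- rand_forest(mode = "regression") |>
    set_engine("ranger") |>
    fit(bill_length_mm ~ ., data = penguins)

avg_comparisons(mod, newdata = penguins, type = "numeric")
```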