Commit 9e7da74: typos
vincentarelbundock committed Oct 6, 2023 (1 parent: 07c1081)
Showing 1 changed file with 3 additions and 59 deletions: `book/articles/machine_learning.qmd`
```{r}
avg_predictions(mod, type = "numeric", newdata = penguins, by = "island")
```
When the underlying engine that `tidymodels` uses to fit the model is not supported by `marginaleffects` as a standalone model, we can still obtain correct point estimates, but no uncertainty estimates. Here is a random forest model:

```{r}
mod <- rand_forest(mode = "regression") |>
    set_engine("ranger") |>
    fit(bill_length_mm ~ ., data = penguins)
avg_comparisons(mod, newdata = penguins, type = "numeric")
```

### Workflows

`tidymodels` "workflows" are a convenient way to train a model while applying a series of pre-processing steps to the data. `marginaleffects` supports workflows out of the box. First, let's consider a simple regression task:

```{r}
penguins <- modeldata::penguins |>
    na.omit() |>
    select(sex, island, species, bill_length_mm)

mod <- penguins |>
    recipe(bill_length_mm ~ island + species + sex, data = _) |>
    step_dummy(all_nominal_predictors()) |>
    workflow(spec = linear_reg(mode = "regression", engine = "glm")) |>
    fit(penguins)
avg_comparisons(mod, newdata = penguins, type = "numeric")
```

Now, we run a classification task instead, and plot the predicted probabilities:

```{r}
mod <- penguins |>
    recipe(sex ~ island + species + bill_length_mm, data = _) |>
    step_dummy(all_nominal_predictors()) |>
    workflow(spec = logistic_reg(mode = "classification", engine = "glm")) |>
    fit(penguins)

plot_predictions(
    mod,
    condition = c("bill_length_mm", "group"),
    newdata = penguins,
    type = "prob")
```
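The fitted workflow is a regular `marginaleffects`-compatible model object, so functions other than `plot_predictions()` should also work on it. A minimal sketch (assuming the classification workflow `mod` fitted in the chunk above):

```{r}
# Sketch: average change in predicted probabilities associated with each
# predictor, using the `mod` classification workflow from the previous chunk.
avg_comparisons(mod, newdata = penguins, type = "prob")
```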

Finally, let's consider a more complex task, where we train several models and summarize them in a table using `modelsummary`:

```{r, error = TRUE}
library(modelsummary)

# pre-processing
pre <- penguins |>
    recipe(sex ~ ., data = _) |>
    step_ns(bill_length_mm, deg_free = 4) |>
    step_dummy(all_nominal_predictors())

# modelling strategies
models <- list(
    "Logit" = logistic_reg(mode = "classification", engine = "glm"),
    "Random Forest" = rand_forest(mode = "classification", engine = "ranger"),
    "XGBoost" = boost_tree(mode = "classification", engine = "xgboost")
)

# fit to data
```

Of course, we can customize the plot using all the standard `ggplot2` functions:

```{r}
plot_predictions(forest, by = "temp", newdata = d) +
    geom_point(data = bikes, aes(x = temp, y = count), alpha = 0.1) +
    geom_smooth(data = bikes, aes(x = temp, y = count), se = FALSE, color = "orange") +
    labs(x = "Temperature (Celsius)", y = "Predicted number of bikes rented per hour",
        title = "Black: random forest predictions. Orange: LOESS smoother.") +
    theme_bw()
```
