diff --git a/book/articles/alternative_software.qmd b/book/articles/alternative_software.qmd index 8fcc5bb3c..72c6beb4d 100644 --- a/book/articles/alternative_software.qmd +++ b/book/articles/alternative_software.qmd @@ -767,6 +767,57 @@ data.frame(bm$ContrastSummary) avg_slopes(ls.fe, dpar = "sigma", re_formula = NA) ``` +## `fmeffects` + +The [`fmeffects` package](https://cran.r-project.org/package=fmeffects) is described as follows: + +> fmeffects: Model-Agnostic Interpretations with Forward Marginal Effects. Create local, regional, and global explanations for any machine learning model with forward marginal effects. You provide a model and data, and 'fmeffects' computes feature effects. The package is based on the theory in: C. A. Scholbeck, G. Casalicchio, C. Molnar, B. Bischl, and C. Heumann (2022) + +As the name says, this package is focused on "forward marginal effects" in the context of machine learning models estimated using the `mlr3` or `tidymodels` frameworks. Since version 0.16.0, `marginaleffects` also supports these machine learning frameworks, and it covers a superset of the `fmeffects` functionality. Consider a random forest model trained on the `bikes` data: + +```{r, message = FALSE, warning = FALSE} +library("mlr3verse") +library("fmeffects") +data("bikes", package = "fmeffects") +task <- as_task_regr(x = bikes, id = "bikes", target = "count") +forest <- lrn("regr.ranger")$train(task) +``` + +Now, we use the `avg_comparisons()` function to compute *centered* marginal effects: + +```{r} +avg_comparisons(forest, variables = "temp", newdata = bikes) +``` + +We call this quantity "centered" because it represents the average effect of a change of 1 unit in `temp` about the observed value, that is, a change from 0.5 below to 0.5 above: + +```{r} +lo <- transform(bikes, temp = temp - 0.5) +hi <- transform(bikes, temp = temp + 0.5) +mean(predict(forest, newdata = hi) - predict(forest, newdata = lo)) +``` + +As described in the [`comparisons()` vignette](comparisons.html), it is easy to estimate "backward", "centered" or "forward" differences by supplying an appropriate function to the `variables` argument. For example, here is how to compute "forward" marginal effects: + +```{r} +avg_comparisons( + forest, + variables = list("temp" = \(x) data.frame(x, x + 1)), + newdata = bikes) +``` + +This is equivalent to the key quantity reported by the `fmeffects` package: + +```{r} +fmeffects::fme( + model = forest, + data = bikes, + target = "count", + feature = "temp", + step.size = 1)$ame +``` + + ## `effects` The [`effects` package](https://cran.r-project.org/package=effects) was created by John Fox and colleagues. diff --git a/book/articles/machine_learning.qmd b/book/articles/machine_learning.qmd index 2eb93647a..327697e7f 100644 --- a/book/articles/machine_learning.qmd +++ b/book/articles/machine_learning.qmd @@ -131,28 +131,6 @@ avg_comparisons( newdata = bikes) ``` -## `fmeffects`: Forward vs. centered effects - -As the code in the `mlr3` section makes clear, the [`avg_comparisons()`](reference/comparisons.html) computes the effect of a "centered" change on the outcome. If we want to compute a "Forward Marginal Effect" instead, we can call: - -```{r} -avg_comparisons( - forest, - variables = list("temp" = \(x) data.frame(x, x + 1)), - newdata = bikes) -``` - -This is equivalent to using the `fmeffects` package: - -```{r} -fmeffects::fme( - model = forest, - data = bikes, - target = "count", - feature = "temp", - step.size = 1)$ame -``` - ## Partial Dependence Plots diff --git a/docs/articles/NEWS.html b/docs/articles/NEWS.html index 69a653d05..d18f3b4c5 100644 --- a/docs/articles/NEWS.html +++ b/docs/articles/NEWS.html @@ -292,14 +292,14 @@
  • 25.4.3 Marginal Effects for Location Scale Models
  • -
  • 25.5 effects
  • -
  • 25.6 modelbased
  • -
  • 25.7 ggeffects
  • +
  • 25.5 fmeffects
  • +
  • 25.6 effects
  • +
  • 25.7 modelbased
  • +
  • 25.8 ggeffects
  • @@ -797,7 +798,7 @@

    #> Term Contrast cyl Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 % #> cyl 6 - 4 4 -0.9049 1.63e+00 -0.55506 0.579 0.8 -4.100 2.29e+00 #> cyl 8 - 4 4 -19.5418 4.37e+03 -0.00447 0.996 0.0 -8579.030 8.54e+03 -#> hp dY/dX 4 -0.0326 3.39e-02 -0.96147 0.336 1.6 -0.099 3.38e-02 +#> hp dY/dX 4 -0.0326 3.39e-02 -0.96140 0.336 1.6 -0.099 3.38e-02 #> #> Columns: rowid, term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, cyl, predicted_lo, predicted_hi, predicted, vs, hp #> Type: link @@ -914,22 +915,22 @@

    head(data.frame(mar))
     #>    mpg cyl disp  hp drat    wt  qsec vs am gear carb   fitted se.fitted   dydx_cyl    dydx_hp   dydx_wt Var_dydx_cyl  Var_dydx_hp Var_dydx_wt X_weights X_at_number
    -#> 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 22.82043 0.6876212 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    -#> 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 22.01285 0.6056817 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    -#> 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.96040 0.7349593 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    -#> 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.93608 0.5800910 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    -#> 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.16780 0.8322986 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    -#> 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 20.25036 0.6638322 -0.9416168 -0.0180381 -3.166973    0.3035104 0.0001410451   0.5484521        NA           1
    +#> 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 22.82043 0.6876212 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
    +#> 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 22.01285 0.6056817 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
    +#> 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.96040 0.7349593 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
    +#> 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.93608 0.5800910 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
    +#> 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.16780 0.8322986 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
    +#> 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 20.25036 0.6638322 -0.9416168 -0.0180381 -3.166973    0.3035074 0.0001410453   0.5484524        NA           1
     
     head(mfx)
     #> 
     #>  Term Estimate Std. Error     z Pr(>|z|)   S 2.5 % 97.5 %
     #>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
     #>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
    -#>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
    -#>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
     #>   cyl   -0.942      0.551 -1.71   0.0875 3.5 -2.02  0.138
     #>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
    +#>   cyl   -0.942      0.551 -1.71   0.0873 3.5 -2.02  0.138
    +#>   cyl   -0.942      0.550 -1.71   0.0871 3.5 -2.02  0.137
     #> 
     #> Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, mpg, cyl, hp, wt 
     #> Type:  response
    @@ -941,14 +942,14 @@ 

    mar <- margins(mod, data = data.frame(prediction::mean_or_mode(mtcars)), unit_ses = TRUE)
     data.frame(mar)
     #>        mpg    cyl     disp       hp     drat      wt     qsec     vs      am   gear   carb   fitted se.fitted   dydx_cyl    dydx_hp   dydx_wt Var_dydx_cyl  Var_dydx_hp Var_dydx_wt SE_dydx_cyl SE_dydx_hp SE_dydx_wt X_weights X_at_number
    -#> 1 20.09062 6.1875 230.7219 146.6875 3.596563 3.21725 17.84875 0.4375 0.40625 3.6875 2.8125 20.09062 0.4439832 -0.9416168 -0.0180381 -3.166973    0.3035013 0.0001410453     0.54846   0.5509096 0.01187625  0.7405808        NA           1
    +#> 1 20.09062 6.1875 230.7219 146.6875 3.596563 3.21725 17.84875 0.4375 0.40625 3.6875 2.8125 20.09062 0.4439832 -0.9416168 -0.0180381 -3.166973    0.3034971 0.0001410454     0.54846   0.5509057 0.01187626  0.7405808        NA           1
     
     slopes(mod, newdata = "mean")
     #> 
     #>  Term Estimate Std. Error     z Pr(>|z|)    S   2.5 %   97.5 %
    -#>   cyl   -0.942     0.5510 -1.71   0.0875  3.5 -2.0216  0.13833
    +#>   cyl   -0.942     0.5506 -1.71   0.0873  3.5 -2.0209  0.13763
     #>   hp    -0.018     0.0119 -1.52   0.1290  3.0 -0.0413  0.00525
    -#>   wt    -3.167     0.7406 -4.28   <0.001 15.7 -4.6185 -1.71549
    +#>   wt    -3.167     0.7406 -4.28   <0.001 15.7 -4.6186 -1.71536
     #> 
     #> Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted, mpg, cyl, hp, wt 
     #> Type:  response
    @@ -964,13 +965,13 @@

    mar <- margins(mod, at = list(cyl = c(4, 6, 8))) summary(mar) #> factor cyl AME SE z p lower upper -#> cyl 4.0000 0.0381 0.6000 0.0636 0.9493 -1.1378 1.2141 -#> cyl 6.0000 0.0381 0.5999 0.0636 0.9493 -1.1376 1.2139 +#> cyl 4.0000 0.0381 0.5999 0.0636 0.9493 -1.1376 1.2139 +#> cyl 6.0000 0.0381 0.5999 0.0636 0.9493 -1.1376 1.2138 #> cyl 8.0000 0.0381 0.5999 0.0636 0.9493 -1.1376 1.2139 #> hp 4.0000 -0.0878 0.0267 -3.2937 0.0010 -0.1400 -0.0355 #> hp 6.0000 -0.0499 0.0154 -3.2397 0.0012 -0.0800 -0.0197 #> hp 8.0000 -0.0120 0.0108 -1.1065 0.2685 -0.0332 0.0092 -#> wt 4.0000 -3.1198 0.6613 -4.7175 0.0000 -4.4160 -1.8236 +#> wt 4.0000 -3.1198 0.6613 -4.7176 0.0000 -4.4160 -1.8236 #> wt 6.0000 -3.1198 0.6613 -4.7175 0.0000 -4.4160 -1.8236 #> wt 8.0000 -3.1198 0.6613 -4.7175 0.0000 -4.4160 -1.8236 @@ -980,15 +981,15 @@

    newdata = datagridcf(cyl = c(4, 6, 8))) #> #> Term Contrast cyl Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 % -#> cyl mean(dY/dX) 4 0.0381 0.5999 0.0636 0.9493 0.1 -1.1377 1.21401 -#> cyl mean(dY/dX) 6 0.0381 0.5998 0.0636 0.9493 0.1 -1.1375 1.21381 -#> cyl mean(dY/dX) 8 0.0381 0.5999 0.0636 0.9493 0.1 -1.1376 1.21389 -#> hp mean(dY/dX) 4 -0.0878 0.0267 -3.2936 <0.001 10.0 -0.1400 -0.03554 -#> hp mean(dY/dX) 6 -0.0499 0.0154 -3.2397 0.0012 9.7 -0.0800 -0.01970 +#> cyl mean(dY/dX) 4 0.0381 0.6000 0.0636 0.9493 0.1 -1.1377 1.21402 +#> cyl mean(dY/dX) 6 0.0381 0.5998 0.0636 0.9493 0.1 -1.1375 1.21378 +#> cyl mean(dY/dX) 8 0.0381 0.5999 0.0636 0.9493 0.1 -1.1377 1.21396 +#> hp mean(dY/dX) 4 -0.0878 0.0267 -3.2937 <0.001 10.0 -0.1400 -0.03555 +#> hp mean(dY/dX) 6 -0.0499 0.0154 -3.2398 0.0012 9.7 -0.0800 -0.01970 #> hp mean(dY/dX) 8 -0.0120 0.0108 -1.1065 0.2685 1.9 -0.0332 0.00923 -#> wt mean(dY/dX) 4 -3.1198 0.6613 -4.7175 <0.001 18.7 -4.4160 -1.82362 -#> wt mean(dY/dX) 6 -3.1198 0.6613 -4.7174 <0.001 18.7 -4.4160 -1.82362 -#> wt mean(dY/dX) 8 -3.1198 0.6613 -4.7174 <0.001 18.7 -4.4160 -1.82362 +#> wt mean(dY/dX) 4 -3.1198 0.6613 -4.7176 <0.001 18.7 -4.4160 -1.82366 +#> wt mean(dY/dX) 6 -3.1198 0.6613 -4.7176 <0.001 18.7 -4.4160 -1.82366 +#> wt mean(dY/dX) 8 -3.1198 0.6613 -4.7176 <0.001 18.7 -4.4160 -1.82366 #> #> Columns: term, contrast, cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted #> Type: response

    @@ -1056,9 +1057,9 @@

    avg_slopes(mod) #> #> Term Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 % -#> cyl -0.942 0.5506 -1.71 0.0872 3.5 -2.0208 0.13753 +#> cyl -0.942 0.5507 -1.71 0.0873 3.5 -2.0209 0.13770 #> hp -0.018 0.0119 -1.52 0.1288 3.0 -0.0413 0.00524 -#> wt -3.167 0.7406 -4.28 <0.001 15.7 -4.6185 -1.71546 +#> wt -3.167 0.7406 -4.28 <0.001 15.7 -4.6185 -1.71549 #> #> Columns: term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high #> Type: response

    @@ -1116,7 +1117,7 @@

    #> #> Term Contrast cyl Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 % #> hp mean(dY/dX) 4 -0.0995 0.0349 -2.853 0.00433 7.9 -0.1678 -0.0311 -#> hp mean(dY/dX) 6 -0.0214 0.0388 -0.551 0.58187 0.8 -0.0975 0.0547 +#> hp mean(dY/dX) 6 -0.0214 0.0388 -0.551 0.58188 0.8 -0.0975 0.0547 #> hp mean(dY/dX) 8 -0.0134 0.0125 -1.074 0.28278 1.8 -0.0380 0.0111 #> #> Columns: term, contrast, cyl, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted @@ -1141,7 +1142,7 @@

    J_mean <- aggregate(J, by = list(mfx$cyl), FUN = mean) J_mean <- as.matrix(J_mean[, 2:ncol(J_mean)]) sqrt(diag(J_mean %*% vcov(mod) %*% t(J_mean))) -#> [1] 0.03486654 0.03882093 0.01251377 +#> [1] 0.03486633 0.03882199 0.01251382

    25.3.3 Average Counterfactual Adjusted Predictions

    @@ -1466,8 +1467,63 @@

    #> Columns: term, contrast, estimate, conf.low, conf.high #> Type: response -

    -25.5 effects +

    +25.5 fmeffects +

    +

    The fmeffects package is described as follows:

    +
    +

    fmeffects: Model-Agnostic Interpretations with Forward Marginal Effects. Create local, regional, and global explanations for any machine learning model with forward marginal effects. You provide a model and data, and ‘fmeffects’ computes feature effects. The package is based on the theory in: C. A. Scholbeck, G. Casalicchio, C. Molnar, B. Bischl, and C. Heumann (2022)

    +
    +

    As the name says, this package is focused on “forward marginal effects” in the context of machine learning models estimated using the mlr3 or tidymodels frameworks. Since version 0.16.0, marginaleffects also supports these machine learning frameworks, and it covers a superset of the fmeffects functionality. Consider a random forest model trained on the bikes data:

    +
    +
    library("mlr3verse")
    +library("fmeffects")
    +data("bikes", package = "fmeffects")
    +task <- as_task_regr(x = bikes, id = "bikes", target = "count")
    +forest <- lrn("regr.ranger")$train(task)
    +
    +

    Now, we use the avg_comparisons() function to compute centered marginal effects:

    +
    +
    avg_comparisons(forest, variables = "temp", newdata = bikes)
    +#> 
    +#>  Term Contrast Estimate
    +#>  temp       +1     3.48
    +#> 
    +#> Columns: term, contrast, estimate 
    +#> Type:  response
    +
    +

    We call this quantity “centered” because it represents the average effect of a change of 1 unit in temp about the observed value, that is, a change from 0.5 below to 0.5 above:

    +
    +
    lo <- transform(bikes, temp = temp - 0.5)
    +hi <- transform(bikes, temp = temp + 0.5)
    +mean(predict(forest, newdata = hi) - predict(forest, newdata = lo))
    +#> [1] 3.477967
    +
    +

    As described in the comparisons() vignette, it is easy to estimate “backward”, “centered” or “forward” differences by supplying an appropriate function to the variables argument. For example, here is how to compute “forward” marginal effects:

    +
    +
    avg_comparisons(
    +    forest,
    +    variables = list("temp" = \(x) data.frame(x, x + 1)),
    +    newdata = bikes)
    +#> 
    +#>  Term Contrast Estimate
    +#>  temp   custom     2.34
    +#> 
    +#> Columns: term, contrast, estimate 
    +#> Type:  response
    +
    +

    This is equivalent to the key quantity reported by the fmeffects package:

    +
    +
    fmeffects::fme(
    +    model = forest,
    +    data = bikes,
    +    target = "count",
    +    feature = "temp",
    +    step.size = 1)$ame
    +#> [1] 2.340716
    +
    +

    +25.6 effects

    The effects package was created by John Fox and colleagues.

    -25.6 modelbased +

    +25.7 modelbased

    The modelbased package is developed by the easystats team.

    This section is incomplete; contributions are welcome.

    -25.7 ggeffects +

    +25.8 ggeffects

    The ggeffects package is developed by Daniel Lüdecke.

    This section is incomplete; contributions are welcome.

    diff --git a/docs/articles/bootstrap.html b/docs/articles/bootstrap.html index 1bb87261e..5d7cbbae1 100644 --- a/docs/articles/bootstrap.html +++ b/docs/articles/bootstrap.html @@ -307,14 +307,14 @@

    diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-16-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-16-1.png index 62f3e56b1..bb73b86f1 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-16-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-16-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-21-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-21-1.png index 359cbe276..e0c220527 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-21-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-21-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-23-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-23-1.png index 22eb1384c..c74b66e99 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-23-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-23-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-27-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-27-1.png index cc379c25c..428ea3a12 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-27-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-27-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-28-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-28-1.png index 7539da8a4..609137f49 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-28-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-28-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-30-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-30-1.png index a8b0d400f..6cf16da83 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-30-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-30-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-32-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-32-1.png index 4031b1f49..784ed1616 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-32-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-32-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-33-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-33-1.png index bcada68d2..befdfe8ea 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-33-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-33-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-34-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-34-1.png index a6bc78c93..2edf353f1 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-34-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-34-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-35-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-35-1.png index a66c70f30..b44ac19e2 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-35-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-35-1.png differ diff --git a/docs/articles/brms_files/figure-html/unnamed-chunk-66-1.png b/docs/articles/brms_files/figure-html/unnamed-chunk-66-1.png index c5ab3a774..adee9f13d 100644 Binary files a/docs/articles/brms_files/figure-html/unnamed-chunk-66-1.png and b/docs/articles/brms_files/figure-html/unnamed-chunk-66-1.png differ diff --git a/docs/articles/categorical.html b/docs/articles/categorical.html index c3ecb1d18..710a37e21 100644 --- a/docs/articles/categorical.html +++ b/docs/articles/categorical.html @@ -307,14 +307,14 @@

    -20.2 mlr3 + +

    +21.2 mlr3

    mlr3 is a machine learning framework for R. It makes it possible for users to train a wide range of models, including linear models, random forests, gradient boosting machines, and neural networks.

    In this example, we use the bikes dataset supplied by the fmeffects package to train a random forest model predicting the number of bikes rented per hour. We then use marginaleffects to interpret the results of the model.

    -
    data("bikes", package = "fmeffects")
    +
    data("bikes", package = "fmeffects")
     
     task <- as_task_regr(x = bikes, id = "bikes", target = "count")
     forest <- lrn("regr.ranger")$train(task)

    As described in other vignettes, we can use the avg_comparisons() function to compute the average change in predicted outcome that is associated with a change in each feature:

    -
    avg_comparisons(forest, newdata = bikes)
    +
    avg_comparisons(forest, newdata = bikes)
    
            Term      Contrast Estimate
      count      +1               0.000
    - holiday    False - True    14.001
    - humidity   +1             -23.438
    - month      +1               3.895
    - season     spring - fall  -29.351
    - season     summer - fall   -7.390
    - season     winter - fall    4.739
    - temp       +1               3.603
    - weather    misty - clear   -7.934
    - weather    rain - clear   -59.467
    - weekday    Fri - Sun       69.384
    - weekday    Mon - Sun       77.766
    - weekday    Sat - Sun       18.296
    - weekday    Thu - Sun       85.197
    - weekday    Tue - Sun       83.673
    - weekday    Wed - Sun       85.878
    - windspeed  +1               0.294
    - workingday False - True  -188.588
    - year       1 - 0           98.440
    + holiday    False - True    15.015
    + humidity   +1             -22.587
    + month      +1               4.071
    + season     spring - fall  -34.097
    + season     summer - fall   -9.406
    + season     winter - fall    2.032
    + temp       +1               3.642
    + weather    misty - clear   -7.974
    + weather    rain - clear   -58.013
    + weekday    Fri - Sun       75.829
    + weekday    Mon - Sun       84.463
    + weekday    Sat - Sun       27.949
    + weekday    Thu - Sun       91.796
    + weekday    Tue - Sun       90.371
    + weekday    Wed - Sun       92.699
    + windspeed  +1               0.203
    + workingday False - True  -187.681
    + year       1 - 0           97.913
     
     Columns: term, contrast, estimate 
     Type:  response 
    -

    These results are easy to interpret: An increase of 1 degree Celsius in the temperature is associated with an increase of 3.603 bikes rented per hour.

    +

    These results are easy to interpret: An increase of 1 degree Celsius in the temperature is associated with an increase of 3.642 bikes rented per hour.

    We could obtain the same result manually as follows:

    -
    lo <- transform(bikes, temp = temp - 0.5)
    +
    lo <- transform(bikes, temp = temp - 0.5)
     hi <- transform(bikes, temp = temp + 0.5)
     mean(predict(forest, newdata = hi) - predict(forest, newdata = lo))
    -
    [1] 3.603054
    +
    [1] 3.64244
    -

    -20.3 Simultaneous changes

    +

    +21.3 Simultaneous changes

    With marginaleffects::avg_comparisons(), we can also compute the average effect of a simultaneous change in multiple predictors, using the variables and cross arguments. In this example, we see what happens (on average) to the predicted outcome when the temp, season, and weather predictors all change together:

    -
    avg_comparisons(
    +
    avg_comparisons(
         forest,
         variables = c("temp", "season", "weather"),
         cross = TRUE,
    @@ -875,50 +1319,21 @@ 

    
      Estimate     C: season C: temp    C: weather
    -  -32.469 spring - fall      +1 misty - clear
    -  -76.390 spring - fall      +1 rain - clear 
    -  -11.574 summer - fall      +1 misty - clear
    -  -60.974 summer - fall      +1 rain - clear 
    -    0.334 winter - fall      +1 misty - clear
    -  -53.340 winter - fall      +1 rain - clear 
    +   -37.51 spring - fall      +1 misty - clear
    +   -79.93 spring - fall      +1 rain - clear 
    +   -13.97 summer - fall      +1 misty - clear
    +   -61.82 summer - fall      +1 rain - clear 
    +    -2.65 winter - fall      +1 misty - clear
    +   -54.58 winter - fall      +1 rain - clear 
     
     Columns: term, contrast_season, contrast_temp, contrast_weather, estimate 
     Type:  response 

    -

    -20.4 fmeffects: Forward vs. centered effects

    -

    As the code in the mlr3 section makes clear, the avg_comparisons() computes the effect of a “centered” change on the outcome. If we want to compute a “Forward Marginal Effect” instead, we can call:

    -
    -
    avg_comparisons(
    -    forest,
    -    variables = list("temp" = \(x) data.frame(x, x + 1)),
    -    newdata = bikes)
    -
    -
    
    - Term Contrast Estimate
    - temp   custom     2.39
    -
    -Columns: term, contrast, estimate 
    -Type:  response 
    -
    -
    -

    This is equivalent to using the fmeffects package:

    -
    -
    fmeffects::fme(
    -    model = forest,
    -    data = bikes,
    -    target = "count",
    -    feature = "temp",
    -    step.size = 1)$ame 
    -
    -
    [1] 2.386648
    -
    -
    -

    -20.5 Partial Dependence Plots

    +

    +21.4 Partial Dependence Plots

    -
    # https://stackoverflow.com/questions/67634344/r-partial-dependence-plots-from-workflow
    +
    # https://stackoverflow.com/questions/67634344/r-partial-dependence-plots-from-workflow
     library("tidymodels")
     library("marginaleffects")
     data(ames, package = "modeldata")
    @@ -949,12 +1364,12 @@ 

    by = c("Gr_Liv_Area", "Bldg_Type")) + labs(x = "Living Area", y = "Predicted log10(Sale Price)", color = "Building Type")

    -

    +

    We can replicate this plot using the DALEXtra package:

    -
    library("DALEXtra")
    +
    library("DALEXtra")
     pdp_rf <- explain_tidymodels(
         m,
         data = dplyr::select(dat, -Sale_Price),
    @@ -967,40 +1382,40 @@ 

    groups = "Bldg_Type") plot(pdp_rf)

    -

    +

    Note that marginaleffects and DALEXtra plots are not exactly identical because the randomly sampled profiles are not the same. You can try the same procedure without sampling — or equivalently with N=2930 — to see a perfect equivalence.

    -

    -20.6 Other Plots

    +

    +21.5 Other Plots

    We can plot the results using the standard marginaleffects helpers. For example, to plot predictions, we can do:

    -
    library(mlr3verse)
    +
    library(mlr3verse)
     data("bikes", package = "fmeffects")
     task <- as_task_regr(x = bikes, id = "bikes", target = "count")
     forest <- lrn("regr.ranger")$train(task)
     
     plot_predictions(forest, condition = "temp", newdata = bikes)
    -

    +

    As documented in ?plot_predictions, using condition="temp" is equivalent to creating an equally-spaced grid of temp values, and holding all other predictors at their means or modes. In other words, it is equivalent to:

    -
    d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
    +
    d <- datagrid(temp = seq(min(bikes$temp), max(bikes$temp), length.out = 100), newdata = bikes)
     p <- predict(forest, newdata = d)
     plot(d$temp, p, type = "l")

    Alternatively, we could plot “marginal” predictions, where replicate the full dataset once for every value of temp, and then average the predicted values over each value of the x-axis:

    -
    plot_predictions(forest, by = "temp", newdata = bikes)
    +
    plot_predictions(forest, by = "temp", newdata = bikes)
    -

    +

    Of course, we can customize the plot using all the standard ggplot2 functions:

    -
    plot_predictions(forest, by = "temp", newdata = d) +
    +
    plot_predictions(forest, by = "temp", newdata = d) +
         geom_point(data = bikes, aes(x = temp, y = count), alpha = 0.1) +
         geom_smooth(data = bikes, aes(x = temp, y = count), se = FALSE, color = "orange") +
         labs(x = "Temperature (Celsius)", y = "Predicted number of bikes rented per hour",
    @@ -1010,7 +1425,7 @@ 

    `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

    -

    +

    @@ -1455,13 +1870,13 @@

    } }); diff --git a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-10-1.png b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-10-1.png index eaccbecf8..3436ca093 100644 Binary files a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-10-1.png and b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-10-1.png differ diff --git a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-11-1.png b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-11-1.png index 156e7514c..ade489592 100644 Binary files a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-11-1.png and b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-11-1.png differ diff --git a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-12-1.png b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-12-1.png index a9590e555..3b8469596 100644 Binary files a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-12-1.png and b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-12-1.png differ diff --git a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-15-1.png b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-15-1.png index 83fc61212..e41a36b7f 100644 Binary files a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-15-1.png and b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-15-1.png differ diff --git a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-16-1.png b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-16-1.png index 193d900d6..4daaac0ef 100644 Binary files a/docs/articles/machine_learning_files/figure-html/unnamed-chunk-16-1.png and b/docs/articles/machine_learning_files/figure-html/unnamed-chunk-16-1.png differ diff --git a/docs/articles/marginaleffects.html b/docs/articles/marginaleffects.html index 92f09cfcc..d4ddc781e 100644 --- a/docs/articles/marginaleffects.html +++ b/docs/articles/marginaleffects.html @@ -307,14 +307,14 @@

    @@ -962,10 +962,10 @@

    │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │ ╞══════╪══════════╪═══════════╪═══════════╪═══╪══════════╪═══════╪═══════════╪══════════╡ -│ wt ┆ dY/dX ┆ -6.61 ┆ 1.87 ┆ … ┆ 0.00164 ┆ 9.25 ┆ -10.5 ┆ -2.76 │ │ wt ┆ dY/dX ┆ -6.61 ┆ 1.87 ┆ … ┆ 0.00165 ┆ 9.25 ┆ -10.5 ┆ -2.76 │ -│ wt ┆ dY/dX ┆ -7.16 ┆ 1.8 ┆ … ┆ 0.000558 ┆ 10.8 ┆ -10.9 ┆ -3.45 │ -│ wt ┆ dY/dX ┆ -3.21 ┆ 2.01 ┆ … ┆ 0.123 ┆ 3.02 ┆ -7.35 ┆ 0.939 │ +│ wt ┆ dY/dX ┆ -6.61 ┆ 1.87 ┆ … ┆ 0.00166 ┆ 9.24 ┆ -10.5 ┆ -2.76 │ +│ wt ┆ dY/dX ┆ -7.16 ┆ 1.8 ┆ … ┆ 0.000559 ┆ 10.8 ┆ -10.9 ┆ -3.45 │ +│ wt ┆ dY/dX ┆ -3.21 ┆ 2.01 ┆ … ┆ 0.124 ┆ 3.02 ┆ -7.35 ┆ 0.94 │ │ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │ │ am ┆ dY/dX ┆ 2.11e+04 ┆ 2.29e+04 ┆ … ┆ 0.367 ┆ 1.45 ┆ -2.62e+04 ┆ 6.83e+04 │ │ am ┆ dY/dX ┆ 8.95e+03 ┆ 1.64e+04 ┆ … ┆ 0.591 ┆ 0.758 ┆ -2.5e+04 ┆ 4.29e+04 │ @@ -1116,8 +1116,8 @@

    
      Term am Estimate Std. Error     z Pr(>|z|)   S 2.5 % 97.5 %
    -   wt  0    -2.68       1.42 -1.89   0.0593 4.1 -5.46  0.106
    -   wt  1    -5.43       2.15 -2.52   0.0116 6.4 -9.65 -1.214
    +   wt  0    -2.68       1.42 -1.89   0.0593 4.1 -5.46  0.105
    +   wt  1    -5.43       2.15 -2.52   0.0116 6.4 -9.65 -1.213
     
     Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, am, predicted_lo, predicted_hi, predicted, mpg, hp, wt 
     Type:  response 
    @@ -1139,7 +1139,7 @@

    │ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │ ╞══════╪══════════╪══════════╪═══════════╪═══╪═════════╪══════╪═══════╪════════╡ │ wt ┆ dY/dX ┆ -2.68 ┆ 1.42 ┆ … ┆ 0.072 ┆ 3.8 ┆ -5.61 ┆ 0.258 │ -│ wt ┆ dY/dX ┆ -5.43 ┆ 2.15 ┆ … ┆ 0.0186 ┆ 5.75 ┆ -9.87 ┆ -0.993 │ +│ wt ┆ dY/dX ┆ -5.43 ┆ 2.15 ┆ … ┆ 0.0187 ┆ 5.74 ┆ -9.88 ┆ -0.986 │ └──────┴──────────┴──────────┴───────────┴───┴─────────┴──────┴───────┴────────┘ Columns: rowid, term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high, predicted, predicted_lo, predicted_hi, am, rownames, mpg, cyl, disp, hp, drat, wt, qsec, vs, gear, carb @@ -1235,7 +1235,7 @@

    np.mean(mod.predict())
    -
    20.090624999999992
    +
    20.090625000000014

    @@ -1257,7 +1257,7 @@

    am mean(1) - mean(0) 0 -1.3830 2.5250 -0.548 0.58388 0.8 -6.3319 3.56589 am mean(1) - mean(0) 1 1.9029 2.3086 0.824 0.40980 1.3 -2.6219 6.42773 hp mean(+1) 0 -0.0343 0.0159 -2.160 0.03079 5.0 -0.0654 -0.00317 - hp mean(+1) 1 -0.0436 0.0213 -2.050 0.04039 4.6 -0.0854 -0.00191 + hp mean(+1) 1 -0.0436 0.0213 -2.050 0.04038 4.6 -0.0854 -0.00191 wt mean(+1) 0 -2.4799 1.2316 -2.014 0.04406 4.5 -4.8939 -0.06595 wt mean(+1) 1 -6.0718 1.9762 -3.072 0.00212 8.9 -9.9451 -2.19846 @@ -1480,8 +1480,8 @@

    
      Term qsec Estimate Std. Error    z Pr(>|z|)   S  2.5 % 97.5 %
    - drat 14.5     5.22       3.80 1.38   0.1690 2.6 -2.221   12.7
    - drat 22.9    10.24       5.15 1.99   0.0469 4.4  0.142   20.3
    + drat 14.5     5.22       3.79 1.38   0.1682 2.6 -2.206   12.7
    + drat 22.9    10.24       5.16 1.98   0.0472 4.4  0.127   20.4
     
     Columns: rowid, term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, qsec, predicted_lo, predicted_hi, predicted, mpg, drat 
     Type:  response 
    @@ -1502,8 +1502,8 @@

    │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │ ╞══════╪══════════╪══════════╪═══════════╪═══╪═════════╪══════╪════════╪═══════╡ -│ drat ┆ dY/dX ┆ 5.22 ┆ 3.8 ┆ … ┆ 0.18 ┆ 2.47 ┆ -2.56 ┆ 13 │ -│ drat ┆ dY/dX ┆ 10.2 ┆ 5.16 ┆ … ┆ 0.0573 ┆ 4.13 ┆ -0.338 ┆ 20.8 │ +│ drat ┆ dY/dX ┆ 5.22 ┆ 3.81 ┆ … ┆ 0.181 ┆ 2.47 ┆ -2.57 ┆ 13 │ +│ drat ┆ dY/dX ┆ 10.2 ┆ 5.16 ┆ … ┆ 0.057 ┆ 4.13 ┆ -0.328 ┆ 20.8 │ └──────┴──────────┴──────────┴───────────┴───┴─────────┴──────┴────────┴───────┘ Columns: rowid, term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high, predicted, predicted_lo, predicted_hi, qsec, rownames, mpg, cyl, disp, hp, drat, wt, vs, am, gear, carb @@ -1551,7 +1551,7 @@

    │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ str ┆ str ┆ str ┆ str ┆ str │ ╞═══════╪══════════╪═══════════╪════════╪═════════╪═══════╪═══════╪═══════╡ -│ b1=b2 ┆ -5.02 ┆ 8.53 ┆ -0.588 ┆ 0.561 ┆ 0.834 ┆ -22.5 ┆ 12.5 │ +│ b1=b2 ┆ -5.02 ┆ 8.53 ┆ -0.588 ┆ 0.561 ┆ 0.833 ┆ -22.5 ┆ 12.5 │ └───────┴──────────┴───────────┴────────┴─────────┴───────┴───────┴───────┘ Columns: term, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high @@ -1574,7 +1574,7 @@

    
      Term Estimate Std. Error    z Pr(>|z|)    S 2.5 % 97.5 %
      drat     7.22      1.365 5.29  < 0.001 23.0 4.549   9.90
    - qsec     1.12      0.433 2.60  0.00942  6.7 0.276   1.97
    + qsec     1.12      0.433 2.59  0.00947  6.7 0.275   1.97
     
     Columns: term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high 
     Type:  response 
    @@ -1601,7 +1601,7 @@

    │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │ ╞══════╪═════════════╪══════════╪═══════════╪═══╪══════════╪══════╪═══════╪═══════╡ -│ qsec ┆ mean(dY/dX) ┆ 1.12 ┆ 0.432 ┆ … ┆ 0.0147 ┆ 6.09 ┆ 0.239 ┆ 2.01 │ +│ qsec ┆ mean(dY/dX) ┆ 1.12 ┆ 0.435 ┆ … ┆ 0.0152 ┆ 6.04 ┆ 0.234 ┆ 2.01 │ │ drat ┆ mean(dY/dX) ┆ 7.22 ┆ 1.37 ┆ … ┆ 1.25e-05 ┆ 16.3 ┆ 4.43 ┆ 10 │ └──────┴─────────────┴──────────┴───────────┴───┴──────────┴──────┴───────┴───────┘ @@ -1639,7 +1639,7 @@

    
      Term Estimate Std. Error    z Pr(>|z|)    S 2.5 % 97.5 % p (NonSup) p (NonInf) p (Equiv)
      drat     7.22      1.365 5.29  < 0.001 23.0 4.549   9.90     0.9999     <0.001    0.9999
    - qsec     1.12      0.433 2.60  0.00942  6.7 0.276   1.97     0.0215     <0.001    0.0215
    + qsec     1.12      0.433 2.59  0.00947  6.7 0.275   1.97     0.0216     <0.001    0.0216
     
     Columns: term, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, statistic.noninf, statistic.nonsup, p.value.noninf, p.value.nonsup, p.value.equiv 
     Type:  response 
    @@ -1657,8 +1657,8 @@

    │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │ ╞══════╪═════════════╪══════════╪═══════════╪═══╪══════════╪══════╪═══════╪═══════╡ +│ qsec ┆ mean(dY/dX) ┆ 1.12 ┆ 0.435 ┆ … ┆ 0.0152 ┆ 6.04 ┆ 0.234 ┆ 2.01 │ │ drat ┆ mean(dY/dX) ┆ 7.22 ┆ 1.37 ┆ … ┆ 1.25e-05 ┆ 16.3 ┆ 4.43 ┆ 10 │ -│ qsec ┆ mean(dY/dX) ┆ 1.12 ┆ 0.432 ┆ … ┆ 0.0147 ┆ 6.09 ┆ 0.239 ┆ 2.01 │ └──────┴─────────────┴──────────┴───────────┴───┴──────────┴──────┴───────┴───────┘ Columns: term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high, statistic_noninf, statistic_nonsup, p_value_noninf, p_value_nonsup, p_value_equiv diff --git a/docs/articles/marginalmeans.html b/docs/articles/marginalmeans.html index 18c0645a0..00fbcfc10 100644 --- a/docs/articles/marginalmeans.html +++ b/docs/articles/marginalmeans.html @@ -307,14 +307,14 @@