```{r performance-setup, include = FALSE}
knitr::opts_chunk$set(fig.path = "figures/")
library(tidymodels)
library(kableExtra)
tidymodels_prefer()
source("ames_snippets.R")
load("RData/lm_fit.RData")
data(ad_data)
set.seed(245)
ad_folds <- vfold_cv(ad_data, repeats = 5)
```
# Judging model effectiveness {#performance}
Once we have a model, we need to know how well it works. A quantitative approach for estimating effectiveness allows us to understand the model, to compare different models, or to tweak the model to improve performance. Our focus in tidymodels is on _empirical validation_; this usually means using data that were not used to create the model as the substrate to measure effectiveness.
:::rmdwarning
The best approach to empirical validation involves using _resampling_ methods that will be introduced in Chapter \@ref(resampling). In this chapter, we will use the test set for illustration purposes and to motivate the need for empirical validation. Keep in mind that the test set can only be used once, as explained in Section \@ref(splitting-methods).
:::
The choice of which metrics to examine can be critical. In later chapters, certain model parameters will be empirically optimized and a primary performance metric will be used to choose the best _sub-model_. Choosing the wrong metric can easily result in unintended consequences. For example, two common metrics for regression models are the root mean squared error (RMSE) and the coefficient of determination (a.k.a. $R^2$). The former measures _accuracy_ while the latter measures _correlation_. These are not necessarily the same thing. This figure demonstrates the difference between the two:
```{r performance-reg-metrics, echo = FALSE}
set.seed(234)
n <- 200
obs <- runif(n, min = 2, max = 20)
reg_ex <- 
  tibble(
    observed = c(obs, obs),
    predicted = c(obs + rnorm(n, sd = 1.5), 5 + .5 * obs + rnorm(n, sd = .5)),
    approach = rep(c("RMSE optimized", "R^2 optimized"), each = n)
  ) %>% 
  mutate(approach = factor(approach, levels = c("RMSE optimized", "R^2 optimized")))

ggplot(reg_ex, aes(x = observed, y = predicted)) + 
  geom_abline(lty = 2) + 
  geom_point(alpha = 0.5) + 
  coord_obs_pred() + 
  facet_wrap(~ approach)
```
The model optimized for RMSE (left panel) has more variability but relatively uniform accuracy across the range of the outcome. The model optimized for $R^2$ (right panel) shows a tighter correlation between the observed and predicted values, but it performs poorly in the tails.
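To make the distinction concrete, here is a minimal base R sketch. The data are hypothetical and deterministic (they are not the data behind the figure), and `rmse_manual()` and `rsq_manual()` are illustrative helpers rather than package functions. $R^2$ is computed as the squared correlation between the observed and predicted values.

```r
# Hypothetical outcome values on an even grid (assumption for illustration)
observed <- seq(2, 20, length.out = 50)

# Predictions that are right on average, with a fixed +/- 1.5 error
pred_unbiased <- observed + rep(c(-1.5, 1.5), times = 25)

# Predictions that are an exact linear function of the outcome but are
# pulled toward the middle: too high for small outcomes, too low for large
pred_compressed <- 5 + 0.5 * observed

# Illustrative helpers (not part of any package)
rmse_manual <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
rsq_manual  <- function(truth, estimate) cor(truth, estimate)^2

metric_comparison <- data.frame(
  approach = c("unbiased", "compressed"),
  rmse = c(rmse_manual(observed, pred_unbiased),
           rmse_manual(observed, pred_compressed)),
  rsq  = c(rsq_manual(observed, pred_unbiased),
           rsq_manual(observed, pred_compressed))
)
metric_comparison
```

The compressed predictions have the higher squared correlation yet the larger RMSE, so the two metrics rank these hypothetical models in opposite orders.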
This chapter will largely focus on the `r pkg(yardstick)` package. Before illustrating syntax, let's explore whether empirical validation using performance metrics is worthwhile when a model is focused on inference rather than prediction.
## Performance metrics and inference
```{r performance-ad-model, include = FALSE}
ad_mod <- logistic_reg() %>% set_engine("glm") 

# Full model: main effects plus all two- and three-way interactions
full_model_fit <-
  ad_mod %>% 
  fit(Class ~ (Genotype + male + age)^3, data = ad_data)

full_model_fit %>% pluck("fit") 

# Reduced model: main effects plus two-way interactions
two_way_fit <-
  ad_mod %>% 
  fit(Class ~ (Genotype + male + age)^2, data = ad_data)

# Likelihood ratio test for the three-way interaction
three_factor_test <- 
  anova(
    full_model_fit %>% pluck("fit"), 
    two_way_fit %>% pluck("fit"),
    test = "LRT"
  )

main_effects_fit <-
  ad_mod %>% 
  fit(Class ~ Genotype + male + age, data = ad_data)

# Likelihood ratio test for the two-way interactions
two_factor_test <- 
  anova(
    two_way_fit %>% pluck("fit"), 
    main_effects_fit %>% pluck("fit"),
    test = "LRT"
  )

# Resampled accuracy of the two-way interaction model
two_factor_rs <- 
  ad_mod %>% 
  fit_resamples(Class ~ (Genotype + male + age)^2, ad_folds)

two_factor_res <- 
  collect_metrics(two_factor_rs) %>% 
  filter(.metric == "accuracy") %>% 
  pull(mean)
```
The effectiveness of any given model depends on how the model will be used. An inferential model is used primarily to understand relationships, and typically is discussed with a strong focus on the choice (and validity) of probabilistic distributions and other generative qualities that define the model. For a model used primarily for prediction, by contrast, predictive strength is primary and concerns about underlying statistical qualities may be less important. Predictive strength is usually focused on how close our predictions come to the observed data, i.e., fidelity of the model predictions to the actual results. This chapter focuses on functions that can be used to measure predictive strength. However, our advice for those developing inferential models is to use these techniques _even when the model will not be used with the primary goal of prediction_.
A longstanding issue with the practice of inferential statistics is that, with a focus purely on inference, it is difficult to assess the credibility of a model. For example, consider the Alzheimer's disease data from @CraigSchapiro, in which `r nrow(ad_data)` patients were studied to determine the factors that influence cognitive impairment. An analysis might take the known risk factors and build a logistic regression model where the outcome is binary (impaired/non-impaired). Let's consider predictors for age, sex, and the Apolipoprotein E genotype. The latter is a categorical variable with the six possible combinations of the three main variants of this gene. Apolipoprotein E is known to have an association with dementia [@Kim:2009p4370].
A superficial, but not uncommon, approach to this analysis would be to fit a large model with main effects and interactions, then use statistical tests to find the minimal set of model terms that are statistically significant at some pre-defined level. If a full model with the three factors and their two- and three-way interactions were used, an initial phase would be to test the interactions using sequential likelihood ratio tests [@HosmerLemeshow].
* When comparing the model with all two-way interactions to one with the additional three-way interaction, the likelihood ratio test produces a p-value of `r three_factor_test[2, "Pr(>Chi)"]`. This implies that there is no evidence that the `r abs(three_factor_test[2, "Df"])` additional model terms associated with the three-way interaction explain enough of the variation in the data to keep them in the model.
* Next, the two-way interactions are similarly evaluated against the model with no interactions. The p-value here is `r two_factor_test[2, "Pr(>Chi)"]`. This is somewhat borderline, but, given the small sample size, it would be prudent to conclude that there is evidence that some of the `r abs(two_factor_test[2, "Df"])` possible two-way interactions are important to the model.
* From here, we would build some explanation of the results. The interactions would be particularly important to discuss since they may spark interesting physiological or neurological hypotheses to be explored further.
While shallow, this analysis strategy is common in practice as well as in the literature. This is especially true if the practitioner has limited formal training in data analysis.
One missing piece of information in this approach is how closely this model fits the actual data. Using resampling methods, discussed in Chapter \@ref(resampling), we can estimate the accuracy of this model to be about `r round(two_factor_res * 100, 1)`%. Accuracy is often a poor measure of model performance; we use it here because it is commonly understood. If the model has `r round(two_factor_res * 100, 1)`% fidelity to the data, should we trust the conclusions produced by the model? We might think so until we realize that the baseline rate of non-impaired patients in the data is `r round(mean(ad_data$Class == "Control") * 100, 1)`%. This means that, despite our statistical analysis, the two-factor model appears to be _only `r round((two_factor_res - mean(ad_data$Class == "Control")) * 100, 1)`% better than a simple heuristic that always predicts patients to be unimpaired_, regardless of the observed data.
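The baseline comparison is simple to compute by hand. The sketch below uses a small, made-up outcome vector and made-up predictions (not the Alzheimer's disease data) to show the arithmetic: the baseline is the accuracy of a rule that always predicts the most frequent class.

```r
# Hypothetical outcomes: 7 of 10 patients are unimpaired ("Control")
truth <- factor(
  c("Control", "Control", "Impaired", "Control", "Control",
    "Impaired", "Control", "Control", "Impaired", "Control"),
  levels = c("Impaired", "Control")
)

# Hypothetical model predictions for the same 10 patients
predicted <- factor(
  c("Control", "Control", "Impaired", "Control", "Control",
    "Control", "Control", "Control", "Impaired", "Impaired"),
  levels = c("Impaired", "Control")
)

# Accuracy of always predicting the most frequent class
baseline_accuracy <- as.numeric(max(prop.table(table(truth))))

# Accuracy of the hypothetical model
model_accuracy <- mean(truth == predicted)

# Improvement over the "always predict the majority" heuristic
improvement <- model_accuracy - baseline_accuracy

c(baseline = baseline_accuracy, model = model_accuracy, improvement = improvement)
```

Reporting the improvement alongside the raw accuracy makes it harder to mistake a high class imbalance for a good model.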
```{block, type = "rmdnote"}
The point of this analysis is to demonstrate the idea that **optimization of statistical characteristics of the model does not imply that the model fits the data well.** Even for purely inferential models, some measure of fidelity to the data should accompany the inferential results. Using this, the consumers of the analyses can calibrate their expectations of the results of the statistical analysis.
```
In the remainder of this chapter, general approaches for evaluating models via empirical validation are discussed. These approaches are grouped by the nature of the outcome data: purely numeric, binary classes, and three or more class levels.
## Regression metrics
Recall from Section \@ref(parsnip-predictions) that tidymodels prediction functions produce tibbles with columns for the predicted values. These columns have consistent names, and the functions in the `r pkg(yardstick)` package that produce performance metrics have consistent interfaces. The functions are data frame-based, as opposed to vector-based, with the general syntax of:
```r
function(data, truth, ...)
```
where `data` is a data frame or tibble and `truth` is the column with the observed outcome values. The ellipsis (`...`) or other arguments are used to specify the column(s) containing the predictions.
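As a small sketch of this interface, the example below calls `rmse()` on a made-up data frame; the column names `observed` and `predicted` and their values are assumptions for illustration. The `r pkg(yardstick)` package also provides vector versions of its metrics with a `_vec` suffix, shown here for comparison.

```r
library(yardstick)

# Hypothetical observed and predicted values
example_res <- data.frame(
  observed  = c(2.1, 3.4, 5.0, 6.2, 7.9),
  predicted = c(2.5, 3.0, 5.4, 6.0, 8.3)
)

# Data frame interface: columns are given as unquoted names and the
# result is a tibble with .metric, .estimator, and .estimate columns
df_result <- rmse(example_res, truth = observed, estimate = predicted)
df_result

# Vector interface: the same metric returned as a single number
vec_result <- rmse_vec(truth = example_res$observed, estimate = example_res$predicted)
vec_result
```

The data frame version is convenient in pipelines and with grouped data, while the vector version can be handy inside other functions.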
To illustrate, let's take the model from Section \@ref(recipes-summary). The `lm_fit` object was a linear regression model whose predictor set was supplemented with an interaction and spline functions for longitude and latitude. It was created from a training set (named `ames_train`). Although we do not advise using the test set at this juncture of the modeling process, it will be used to illustrate functionality and syntax. The data frame `ames_test` consists of `r nrow(ames_test)` properties. To start, let's produce predictions:
```{r performance-predict-ames}
ames_test_res <- predict(lm_fit, new_data = ames_test %>% select(-Sale_Price))
ames_test_res
```
The predicted numeric outcome from the regression model is named `.pred`. Let's match the predicted values with their corresponding observed outcome values:
```{r performance-ames-outcome}
ames_test_res <- bind_cols(ames_test_res, ames_test %>% select(Sale_Price))
ames_test_res
```
Note that both the predicted and observed outcomes are in log10 units. It is best practice to analyze the predictions on the transformed scale (if one was used) even if the predictions are reported using the original units.
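A minimal sketch of that workflow, using made-up log10 values rather than the Ames results: the metric is computed on the log10 scale, and the predictions are back-transformed only for reporting. The `pred_dollars` and `price_dollars` column names are hypothetical.

```r
# Hypothetical values on the log10 scale (not taken from ames_test)
log_res <- data.frame(
  Sale_Price = c(5.10, 5.25, 5.40),
  .pred      = c(5.12, 5.20, 5.45)
)

# Compute the metric on the scale used to fit the model
rmse_log10 <- sqrt(mean((log_res$Sale_Price - log_res$.pred)^2))
rmse_log10

# Back-transform only when reporting values in the original units
log_res$price_dollars <- 10^log_res$Sale_Price
log_res$pred_dollars  <- 10^log_res$.pred
log_res
```

An error of 0.05 log10 units corresponds to a multiplicative factor of `10^0.05` (about 12%) whatever the price, which is one reason the transformed scale is the more natural place to judge the model.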
Let's plot the data before computing metrics:
```{r performance-ames-plot}
ggplot(ames_test_res, aes(x = Sale_Price, y = .pred)) +
# Create a diagonal line:
geom_abline(lty = 2) +
geom_point(alpha = 0.5) +
labs(y = "Predicted Sale Price (log10)", x = "Sale Price (log10)") +
# Scale and size the x- and y-axis uniformly:
coord_obs_pred()
```
There is one property that is substantially over-predicted.
Let's compute the root mean squared error for this model using the `rmse()` function:
```{r performance-ames-rmse}
rmse(ames_test_res, truth = Sale_Price, estimate = .pred)
```
This output shows the standard format for results from `r pkg(yardstick)` functions. Metrics for numeric outcomes usually have a value of "standard" for the `.estimator` column. Examples with different values for this column are shown in the next sections.
To compute multiple metrics at once, we can create a _metric set_. Let's add $R^2$ and the mean absolute error:
```{r performance-metric-set}
ames_metrics <- metric_set(rmse, rsq, mae)
ames_metrics(ames_test_res, truth = Sale_Price, estimate = .pred)
```
This tidy data format stacks the metrics vertically.
:::rmdwarning
The `r pkg(yardstick)` package does _not_ contain a function for adjusted $R^2$. This commonly used modification of the coefficient of determination is needed when the same data used to fit the model are used to evaluate the model. This metric is not fully supported in tidymodels because it is always a better approach to compute performance on a separate data set than the one used to fit the model.
:::
## Binary classification metrics
To illustrate other ways to measure model performance, we will switch to a different example. The `r pkg(modeldata)` package contains example predictions from a test data set with two classes ("Class1" and "Class2"):
```{r performance-two-class-example}
data(two_class_example)
str(two_class_example)
```
The second and third columns are the predicted class probabilities for the test set, while the `predicted` column contains the discrete predictions.
For the hard class predictions, there are a variety of `r pkg(yardstick)` functions that are helpful:
```{r performance-class-metrics}
# A confusion matrix:
conf_mat(two_class_example, truth = truth, estimate = predicted)
accuracy(two_class_example, truth = truth, estimate = predicted)
# Matthews correlation coefficient:
mcc(two_class_example, truth, predicted)
# F1 metric:
f_meas(two_class_example, truth, predicted)
```
For binary classification data sets, these functions have a standard argument called `event_level`. The _default_ is that the **first** level of the outcome factor is the event of interest.
```{block, type = "rmdnote"}
There is some heterogeneity in R functions in this regard; some use the first level and others the second to denote the event of interest. We consider it more intuitive that the first level is the most important. The second level logic is born of encoding the outcome as 0/1 (in which case the second value is the event) and unfortunately remains in some packages. However, tidymodels (along with many other R packages) _requires_ a categorical outcome to be encoded as a factor and, for this reason, the legacy justification for the second level as the event becomes irrelevant.
```
As an example where the second class is the event:
```{r performance-2nd-level}
f_meas(two_class_example, truth, predicted, event_level = "second")
```
In the output above, the `.estimator` value of "binary" indicates that the standard formula for binary classes was used.
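To see why the choice of event level matters, the base R sketch below computes the F1 score by hand, treating each factor level in turn as the event. The data are made up, and `f1_for()` is an illustrative helper, not a `r pkg(yardstick)` function.

```r
# Hypothetical outcomes; "Class1" is the first factor level
truth <- factor(
  c("Class1", "Class1", "Class1", "Class1", "Class2", "Class2"),
  levels = c("Class1", "Class2")
)
predicted <- factor(
  c("Class1", "Class1", "Class1", "Class2", "Class2", "Class1"),
  levels = c("Class1", "Class2")
)

# Illustrative helper: F1 when `event` is treated as the class of interest
f1_for <- function(truth, predicted, event) {
  true_pos  <- sum(truth == event & predicted == event)
  precision <- true_pos / sum(predicted == event)
  recall    <- true_pos / sum(truth == event)
  2 * precision * recall / (precision + recall)
}

f1_first  <- f1_for(truth, predicted, levels(truth)[1])
f1_second <- f1_for(truth, predicted, levels(truth)[2])

c(first = f1_first, second = f1_second)
```

The same predictions give different F1 values depending on which level is the event, so it is worth checking the factor level order before interpreting any event-based metric.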
There are numerous classification metrics that use the predicted probabilities as inputs rather than the hard class predictions. For example, the receiver operating characteristic (ROC) curve computes the sensitivity and specificity over a continuum of different event thresholds. The predicted class column is not used. There are two `r pkg(yardstick)` functions for this method: `roc_curve()` computes the data points that make up the ROC curve and `roc_auc()` computes the area under the curve.
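The idea behind the curve can be sketched in base R. Using made-up class probabilities, the code below computes the sensitivity and specificity at three hypothetical thresholds; a full ROC curve would use every distinct threshold in the data.

```r
# Hypothetical truth and predicted probabilities for the event ("Class1")
truth <- factor(
  c("Class1", "Class1", "Class1", "Class2",
    "Class1", "Class2", "Class2", "Class2"),
  levels = c("Class1", "Class2")
)
prob_event <- c(0.95, 0.80, 0.65, 0.60, 0.40, 0.35, 0.20, 0.05)

# A few illustrative probability cutoffs
thresholds <- c(0.25, 0.50, 0.75)

roc_points <- do.call(rbind, lapply(thresholds, function(cut) {
  is_event     <- truth == "Class1"
  called_event <- prob_event >= cut
  data.frame(
    threshold   = cut,
    # Share of true events that are called events
    sensitivity = sum(called_event & is_event) / sum(is_event),
    # Share of true non-events that are called non-events
    specificity = sum(!called_event & !is_event) / sum(!is_event)
  )
}))
roc_points
```

Raising the threshold trades sensitivity for specificity, and the ROC curve traces that trade-off across all thresholds.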
The interfaces to these types of metric functions use the `...` argument placeholder to pass in the appropriate class probability column. For two-class problems, the probability column for the event of interest is passed into the function:
```{r performance-2class-roc}
two_class_curve <- roc_curve(two_class_example, truth, Class1)
two_class_curve
roc_auc(two_class_example, truth, Class1)
```
The `two_class_curve` object can be used in a `ggplot()` call to visualize the curve. There is an `autoplot()` method that will take care of the details:
```{r performance-2class-roc-curve}
autoplot(two_class_curve)
```
There are a number of other functions that use probability estimates, including `gain_curve()`, `lift_curve()`, and `pr_curve()`.
## Multi-class classification metrics
What about data with three or more classes? To demonstrate, let's explore a different example data set that has four classes:
```{r performance-hpc-example}
data(hpc_cv)
str(hpc_cv)
```
As before, there are factors for the observed and predicted outcomes along with four other columns of predicted probabilities for each class. These data also include a `Resample` column. These results are for out-of-sample predictions associated with 10-fold cross-validation (discussed in Chapter \@ref(resampling)). For the time being, this column will be ignored.
The functions for metrics that use the discrete class predictions are identical to their binary counterparts:
```{r performance-multiclass-pred}
accuracy(hpc_cv, obs, pred)
mcc(hpc_cv, obs, pred)
```
Note that, in these results, a "multiclass" `.estimator` is listed. Like "binary", this indicates that the formula for outcomes with three or more class levels was used. The Matthews correlation coefficient was originally designed for two classes but has been extended to cases with more class levels.
Metrics that are specific to two-class outcomes can also be adapted to data sets with more than two classes. For example, a metric such as sensitivity measures the true positive rate which, by definition, is specific to two classes (i.e., "event" and "non-event"). How can this metric be used in our example data?
There are wrapper methods that can be used to apply sensitivity to our four-class outcome. These options are macro-, macro-weighted, and micro-averaging:
* Macro-averaging computes a set of one-versus-all metrics using the standard two-class statistics. These are averaged.
* Macro-weighted averaging does the same but the average is weighted by the number of samples in each class.
* Micro-averaging computes the contribution for each class, aggregates them, then computes a single metric from the aggregates.
See @wu2017unified and @OpitzBurst.
Using sensitivity as an example, the usual two-class calculation is the number of correctly predicted events divided by the number of true events. The "manual" calculations for these averaging methods are:
```{r performance-sens-manual}
class_totals <- 
  count(hpc_cv, obs, name = "totals") %>% 
  mutate(class_wts = totals / sum(totals))
class_totals

cell_counts <- 
  hpc_cv %>% 
  group_by(obs, pred) %>% 
  count() %>% 
  ungroup()

# Compute the four sensitivities using 1-vs-all
one_versus_all <- 
  cell_counts %>% 
  filter(obs == pred) %>% 
  full_join(class_totals, by = "obs") %>% 
  mutate(sens = n / totals)
one_versus_all

# Three different estimates:
one_versus_all %>% 
  summarize(
    macro = mean(sens), 
    macro_wts = weighted.mean(sens, class_wts),
    micro = sum(n) / sum(totals)
  )
```
Thankfully, there are easier methods for obtaining these results:
```{r performance-sens}
sensitivity(hpc_cv, obs, pred, estimator = "macro")
sensitivity(hpc_cv, obs, pred, estimator = "macro_weighted")
sensitivity(hpc_cv, obs, pred, estimator = "micro")
```
For metrics using probability estimates, there are some metrics with multi-class analogs. For example, @HandTill determined a multi-class technique for ROC curves. In this case, _all_ of the class probability columns must be given to the function:
```{r performance-multi-class-roc}
roc_auc(hpc_cv, obs, VF, F, M, L)
```
Macro-weighted averaging is also available:
```{r performance-multi-class-roc-macro}
roc_auc(hpc_cv, obs, VF, F, M, L, estimator = "macro_weighted")
```
Finally, all of these performance metrics can be computed using `r pkg(dplyr)` groupings. Recall that these data have a column for the resampling groups. Passing a grouped data frame to the metric function will compute the metrics for each group:
```{r performance-multi-class-acc-grouped}
hpc_cv %>% 
  group_by(Resample) %>% 
  accuracy(obs, pred)
```
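To make explicit what the grouped call computes, here is a base R sketch on a small, made-up set of predictions from two hypothetical folds. It splits the rows by fold and takes the proportion of matching labels within each piece.

```r
# Hypothetical out-of-sample predictions from two resamples
class_levels <- c("VF", "F", "M", "L")
cv_res <- data.frame(
  Resample = rep(c("Fold01", "Fold02"), each = 4),
  obs  = factor(c("VF", "F", "M", "L", "VF", "VF", "F", "M"), levels = class_levels),
  pred = factor(c("VF", "F", "L", "L", "VF", "F",  "F", "L"), levels = class_levels)
)

# Accuracy within each fold: split the rows by fold, then compute the
# share of predictions that match the observed class
fold_accuracy <- sapply(
  split(cv_res, cv_res$Resample),
  function(fold) mean(fold$obs == fold$pred)
)
fold_accuracy

# A single summary across folds
mean(fold_accuracy)
```

Keeping the per-fold values, rather than only their mean, also shows how much the metric varies from one resample to the next.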
The groupings also translate to the `autoplot()` methods:
```{r performance-multi-class-roc-grouped}
# Four 1-vs-all ROC curves for each fold
hpc_cv %>% 
  group_by(Resample) %>% 
  roc_curve(obs, VF, F, M, L) %>% 
  autoplot()
```
This can be a quick visualization method for model effectiveness.
## Chapter summary {#performance-summary}
Functions from the `r pkg(yardstick)` package measure the effectiveness of a model using data. The primary interface is based on data frames (as opposed to having vector arguments). There are a variety of regression and classification metrics and, within these, there are sometimes different estimators for the statistics.