-
Notifications
You must be signed in to change notification settings - Fork 0
/
03_model.R
136 lines (95 loc) · 3.82 KB
/
03_model.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Run the shared setup script (presumably attaches the packages used below).
source("00_packages.R")

# The .qs files were already cleaned during the exploration step; here we only
# discard rows whose outcome (general_health) is missing.
data <- filter(
  qread("data/brfss/brfss_level_1.qs"),
  !is.na(general_health)
)
# Alternative input file:
# data <- filter(qread("data/brfss/brfss_level_2.qs"), !is.na(general_health))

# Fix the RNG state so the split and the folds below are reproducible.
set.seed(123)

# Train/test split, stratified on the outcome.
data_split <- data |> initial_split(strata = general_health)
train_data <- data_split |> training()
test_data <- data_split |> testing()

# 10-fold cross-validation folds (single repeat), stratified on the outcome.
folds <- train_data |>
  vfold_cv(v = 10, repeats = 1, strata = general_health)

# Class distribution of the outcome in the training set.
train_data |> count(general_health)
# Recipe for classifying general_health --------------------------------------
# Untrained preprocessing specification. This is the object handed to the
# workflow further down: a workflow trains the recipe itself during fit(), so
# it must receive the specification, not an already prepped recipe.
gh_rec_spec <- recipe(general_health ~ ., data = train_data) |>
  # keep `id` in the data but out of the predictor set
  update_role(id, new_role = "id") |>
  step_impute_median(all_numeric_predictors()) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_unknown(all_nominal_predictors()) |>
  step_other(all_nominal_predictors()) |>
  # fix the reference level of these factors before dummy coding
  step_relevel(have_any_health_care_coverage, ref_level = "Yes") |>
  step_relevel(could_not_see_doctor_because_of_cost, ref_level = "No") |>
  step_relevel(smoked_at_least_100_cigarettes, ref_level = "No") |>
  step_dummy(all_nominal_predictors()) |>
  step_nzv(all_predictors()) |>
  # A selector is required here: without one the step has nothing to
  # normalize. After step_dummy() every remaining predictor is numeric.
  step_normalize(all_numeric_predictors())
  # Optional class balancing, currently disabled:
  # step_smote(general_health)

# Trained copy of the recipe, kept under the original name because the
# bake() calls in this script need a prepped recipe.
gh_rec <- prep(gh_rec_spec)
summary(gh_rec)

# What does the data look like for the model? Any NA left?
baked_data <- bake(gh_rec, new_data = train_data)
map_dbl(baked_data, \(x) sum(is.na(x)))

# Classification engines -----------------------------------------------------
# random forest classification
rf_class <- rand_forest(trees = 1000) |>
  set_mode("classification") |>
  set_engine("ranger")

# decision tree classification
dt_class <- decision_tree() |>
  set_mode("classification") |>
  set_engine("rpart")

# boosted trees classification
xgb_class <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")

# Base workflow for general_health: holds only the (untrained) recipe; each
# model specification is added to it at fit time.
base_wf <- workflow() |>
  add_recipe(gh_rec_spec)
# Fit models -----------------------------------------------------------------
# Each fit attaches one model specification to the shared base workflow and
# trains the resulting workflow on the training data.

# Random forest
fit_rf_gh <- fit(add_model(base_wf, rf_class), data = train_data)

# Decision tree
fit_dt_gh <- fit(add_model(base_wf, dt_class), data = train_data)

# Boosted trees (xgboost)
fit_xgb_gh <- fit(add_model(base_wf, xgb_class), data = train_data)
# Evaluate performance on the training data ----------------------------------
# Add each model's predicted class to the training data, one column per model.
train_data <- train_data |>
  mutate(
    .pred_rf = predict(fit_rf_gh, train_data)[[".pred_class"]],
    .pred_dt = predict(fit_dt_gh, train_data)[[".pred_class"]],
    .pred_xgb = predict(fit_xgb_gh, train_data)[[".pred_class"]]
  )

# Side-by-side look at the truth and the three predictions. View() opens a
# data viewer, so it is only called in an interactive session; in a batch run
# (e.g. Rscript) there is nobody to look at it and the call may fail.
if (interactive()) {
  train_data |>
    select(general_health, starts_with(".pred")) |>
    View()
}

# Accuracy per model on the data the models were trained on.
accuracy(train_data, truth = general_health, estimate = .pred_rf)
accuracy(train_data, truth = general_health, estimate = .pred_dt)
accuracy(train_data, truth = general_health, estimate = .pred_xgb)
# Evaluate performance on the test data --------------------------------------
# Add the random forest predictions to the test data
test_data_rf <- fit_rf_gh |> augment(new_data = test_data)
test_data_rf |> select(general_health, starts_with(".pred"))

# Add the decision tree predictions to the test data
test_data_dt <- fit_dt_gh |> augment(new_data = test_data)
test_data_dt |> select(general_health, starts_with(".pred"))

# Add the xgboost predictions to the test data
test_data_xgb <- fit_xgb_gh |> augment(new_data = test_data)
test_data_xgb |> select(general_health, starts_with(".pred"))

# How good is the accuracy on the test data?
test_data_rf |> accuracy(truth = general_health, estimate = .pred_class)
test_data_dt |> accuracy(truth = general_health, estimate = .pred_class)
test_data_xgb |> accuracy(truth = general_health, estimate = .pred_class)

# Precision and recall of the random forest.
# NOTE(review): these two use the training-data predictions (.pred_rf), not
# the test data like the accuracy calls above - confirm this is intended.
train_data |> precision(truth = general_health, estimate = .pred_rf)
train_data |> recall(truth = general_health, estimate = .pred_rf)
# Variable importance --------------------------------------------------------
# bake() returns every column by default, including the `id` column that the
# recipe only tags with the "id" role. Fitting `general_health ~ .` on that
# output would treat `id` as a predictor, so keep just the outcome and the
# actual predictors.
importance_data <- bake(
  gh_rec,
  new_data = train_data,
  all_outcomes(),
  all_predictors()
)

# Permutation importance from a random forest fitted directly on the
# preprocessed training data.
rf_class |>
  set_engine("ranger", importance = "permutation") |>
  fit(general_health ~ ., data = importance_data) |>
  vip(geom = "point")

# Variable importance from a single rpart tree on the same data.
tree <- rpart(general_health ~ ., data = importance_data)
vi_tree <- tree$variable.importance
print(vi_tree)
barplot(vi_tree, horiz = TRUE, las = 1)