---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```
# parttree
<!-- badges: start -->
<!-- badges: end -->
A set of simple functions for visualizing decision tree partitions in R with
[**ggplot2**](https://ggplot2.tidyverse.org/).
## Installation
This package is not yet on CRAN, but can be installed from [GitHub](https://github.com/)
with:
``` r
# install.packages("remotes")
remotes::install_github("grantmcdermott/parttree")
```
## Example
The main function that users will interact with is `geom_parttree()`. Here's a
simple example using the [palmerpenguins](https://allisonhorst.github.io/palmerpenguins/)
dataset.
```{r penguin_plot}
library(palmerpenguins) ## For 'penguins' dataset
library(rpart)          ## For fitting decision trees
library(parttree)       ## This package (will automatically load ggplot2 too)

## First construct a scatterplot of the raw penguin data
p = ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species)) +
  theme_minimal()

## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Visualise the tree partitions by adding them via geom_parttree()
p +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(caption = "Note: Points denote observed data. Shaded regions denote tree predictions.")
```
Trees with a continuous dependent (response) variable are also supported.
However, I recommend adjusting the plot fill aesthetic, since your tree will
likely partition the data into intervals whose predicted values don't match up
exactly with the raw data.
```{r penguin_plot2}
tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

p2 =
  ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree2, aes(fill = body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) +
  theme_minimal()

## Legend scales don't quite match (try it yourself)
# p2

## Better to scale fill to the original data
## This does the job but is still kind of hard to make out (again, try yourself)
# p2 +
#   scale_fill_continuous(limits = range(penguins$body_mass_g, na.rm = TRUE))

## Even better to combine fill scaling with a mixed colour palette
p2 +
  scale_colour_viridis_c(
    limits = range(penguins$body_mass_g, na.rm = TRUE),
    aesthetics = c('colour', 'fill')
  )
```
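The legend mismatch mentioned in the comments above comes from the fact that a regression tree predicts a single value (the mean response) for each terminal node, so the fill values typically span a narrower range than the raw data. Here is a minimal check of this, refitting the same model and using only **rpart** functions. The `leaf_means` name is purely illustrative.

```{r penguin_leaf_means}
library(palmerpenguins)
library(rpart)

tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

## One fitted value per terminal node (the mean body mass in that node)
leaf_means = sort(unique(predict(tree2)))
leaf_means

## Compare the span of the predictions with the span of the raw data
range(leaf_means)
range(penguins$body_mass_g, na.rm = TRUE)
```

Setting the scale limits to the raw data range, as done above, puts the points and the shaded regions on the same colour scale.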
## Limitations and caveats
### Supported model classes
Currently, the package only works with decision trees created by the
[**rpart**](https://cran.r-project.org/web/packages/rpart/index.html) package.
However, it does support other front-end packages that call `rpart::rpart()` as
the underlying engine; in particular the
[**parsnip**](https://tidymodels.github.io/parsnip/) and
[**mlr3**](https://mlr3.mlr-org.com/) packages. Here's an example with the
former.
```{r titanic_plot}
library(parsnip)
library(titanic) ## Just for a different data set
set.seed(123)    ## For consistent jitter

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train %>%
  ggplot(aes(x = Pclass, y = Age)) +
  geom_jitter(aes(col = Survived), alpha = 0.7) +
  geom_parttree(data = ti_tree, aes(fill = Survived), alpha = 0.1) +
  theme_minimal()
```
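Front ends like **parsnip** can work here because the fitted object still carries an ordinary **rpart** tree inside it. Here is a small sketch to confirm this, refitting the same model as above. It assumes the usual **parsnip** layout, in which the engine fit is stored in the `fit` element of the returned object.

```{r titanic_engine}
library(parsnip)
library(titanic)

titanic_train$Survived = as.factor(titanic_train$Survived)

ti_tree =
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Survived ~ Pclass + Age, data = titanic_train)

## The underlying engine fit should be a plain rpart object
class(ti_tree$fit)
inherits(ti_tree$fit, "rpart")
```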
### Plot orientation
Under the hood, `geom_parttree()` calls the companion `parttree()` function,
which coerces the **rpart** tree object into a data frame that is easily
understood by **ggplot2**. For example, consider again our first "tree" model
from earlier. Here's the print output of the raw model.
```{r tree}
tree
```
And here's what we get after we feed it to `parttree()`.
```{r tree_parted}
parttree(tree)
```
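To get a feel for where those partition bounds come from, we can list the split rules that lead to each terminal node using **rpart**'s own `path.rpart()` function. This is only a rough illustration of the information involved, and not necessarily how `parttree()` is implemented internally. The `leaf_nodes` and `leaf_paths` names are just for this sketch.

```{r tree_paths}
library(palmerpenguins)
library(rpart)

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Node numbers of the terminal nodes (leaves)
leaf_nodes = as.numeric(rownames(tree$frame)[tree$frame$var == "<leaf>"])

## Split rules along the path from the root to each leaf
leaf_paths = path.rpart(tree, nodes = leaf_nodes, print.it = FALSE)
leaf_paths
```

Each leaf's path is a set of inequalities on the two predictors, which is what defines a rectangle in the plot.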
Again, the resulting data frame is designed to be amenable to a **ggplot2** geom
layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that
**ggplot2** recognises. (Fun fact: `geom_parttree()` is really just a thin
wrapper around `geom_rect()`.) The goal of the package is to abstract away these
kinds of details from the user, so we can just specify `geom_parttree()`, with a
valid tree object as the data input, and be done with it. However, while this
generally works well, it can sometimes lead to unexpected behaviour in terms of
plot orientation. That's because it's hard to guess ahead of time what the user
will specify as the x and y variables (i.e. axes) in their other plot layers. To
see what I mean, let's redo our penguin plot from earlier, but this time switch
the axes in the main `ggplot()` call.
```{r tree_plot_mismatch}
## First, redo our first plot but this time switch the x and y variables
p3 =
  ggplot(
    data = penguins,
    aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
  ) +
  geom_point(aes(col = species)) +
  theme_minimal()

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
  )
```
As was the case here, this kind of orientation mismatch should normally
(hopefully) be pretty easy to recognize. To fix it, we can use the
`flipaxes = TRUE` argument to flip the orientation of the `geom_parttree()`
layer.
```{r tree_plot_flip}
p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flipaxes = TRUE ## Flip the orientation
  ) +
  labs(title = "That's better")
```
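For intuition about what the flip involves, recall that the layer boils down to rectangles. The sketch below draws them by hand with `geom_rect()` and swaps the x and y bounds manually. It assumes that the `parttree()` data frame has `ymin` and `ymax` columns alongside `xmin` and `xmax`, plus a `species` column holding the predicted class. It is meant as an illustration only, not as a description of how `flipaxes` is implemented.

```{r tree_plot_manual_flip}
library(palmerpenguins)
library(rpart)
library(parttree) ## Also loads ggplot2

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
tree_df = parttree(tree)

ggplot(penguins, aes(x = bill_length_mm, y = flipper_length_mm)) +
  geom_point(aes(col = species)) +
  geom_rect(
    data = tree_df,
    ## Assumed columns: swap the bounds so they follow the switched axes
    aes(xmin = ymin, xmax = ymax, ymin = xmin, ymax = xmax, fill = species),
    alpha = 0.1, inherit.aes = FALSE
  ) +
  theme_minimal()
```

In practice `flipaxes = TRUE` is the simpler route, since it avoids having to handle the rectangle bounds directly.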