-
-
Notifications
You must be signed in to change notification settings - Fork 194
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #692 from imbs-hl/hierarchical_shrinkage
Hierarchical shrinkage
- Loading branch information
Showing
9 changed files
with
317 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ Package: ranger | |
Type: Package | ||
Title: A Fast Implementation of Random Forests | ||
Version: 0.16.1 | ||
Date: 2024-05-15 | ||
Date: 2024-05-16 | ||
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] | ||
Maintainer: Marvin N. Wright <[email protected]> | ||
Description: A fast implementation of Random Forests, particularly suited for high | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# ------------------------------------------------------------------------------- | ||
# This file is part of Ranger. | ||
# | ||
# Ranger is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# Ranger is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with Ranger. If not, see <http://www.gnu.org/licenses/>. | ||
# | ||
# Written by: | ||
# | ||
# Marvin N. Wright | ||
# Institut fuer Medizinische Biometrie und Statistik | ||
# Universitaet zu Luebeck | ||
# Ratzeburger Allee 160 | ||
# 23562 Luebeck | ||
# Germany | ||
# | ||
# http://www.imbs-luebeck.de | ||
# ------------------------------------------------------------------------------- | ||
|
||
|
||
#' Hierarchical shrinkage | ||
#' | ||
#' Apply hierarchical shrinkage to a ranger object. | ||
#' Hierarchical shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. | ||
#' For details see Agarwal et al. (2022). | ||
#' | ||
#' @param rf ranger object, created with \code{node.stats = TRUE}. | ||
#' @param lambda Non-negative shrinkage parameter. | ||
#' | ||
#' @return The ranger object is modified in-place. | ||
#' | ||
#' @examples | ||
##' @references | ||
##' \itemize{ | ||
##' \item Agarwal, A., Tan, Y.S., Ronen, O., Singh, C. & Yu, B. (2022). Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models. Proceedings of the 39th International Conference on Machine Learning, PMLR 162:111-135. | ||
##' } | ||
#' @author Marvin N. Wright | ||
#' @export | ||
hshrink <- function(rf, lambda) { | ||
if (is.null(rf$forest$num.samples.nodes)) { | ||
stop("Hierarchical shrinkage needs node statistics, set node.stats=TRUE in ranger() call.") | ||
} | ||
if (lambda < 0) { | ||
stop("Shrinkage parameter lambda has to be non-negative.") | ||
} | ||
|
||
if (rf$treetype == "Regression") { | ||
invisible(lapply(1:rf$num.trees, function(treeID) { | ||
hshrink_regr( | ||
rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]], | ||
rf$forest$num.samples.nodes[[treeID]], rf$forest$node.predictions[[treeID]], | ||
rf$forest$split.values[[treeID]], lambda, 0, 0, 0, 0 | ||
) | ||
})) | ||
} else if (rf$treetype == "Probability estimation") { | ||
invisible(lapply(1:rf$num.trees, function(treeID) { | ||
# Create temporary class frequency matrix | ||
class_freq <- t(simplify2array(rf$forest$terminal.class.counts[[treeID]])) | ||
|
||
parent_pred <- rep(0, length(rf$forest$class.values)) | ||
cum_sum <- rep(0, length(rf$forest$class.values)) | ||
hshrink_prob( | ||
rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]], | ||
rf$forest$num.samples.nodes[[treeID]], class_freq, | ||
lambda, 0, 0, parent_pred, cum_sum | ||
) | ||
|
||
# Assign temporary matrix values back to ranger object | ||
replace_class_counts(rf$forest$terminal.class.counts[[treeID]], class_freq) | ||
})) | ||
} else if (rf$treetype == "Classification") { | ||
stop("To apply hierarchical shrinkage to classification forests, use probability=TRUE in the ranger() call.") | ||
} else if (rf$treetype == "Survival") { | ||
stop("Hierarchical shrinkage not yet implemented for survival.") | ||
} else { | ||
stop("Unknown treetype.") | ||
} | ||
|
||
} | ||
|
||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
## Tests for hierarchical shrinkage | ||
|
||
library(ranger) | ||
context("ranger_hshrink") | ||
|
||
## Tests | ||
test_that("hierarchical shrinkage gives an error when node.stats=FALSE", { | ||
rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = FALSE) | ||
expect_error(hshrink(rf, lambda = 5)) | ||
}) | ||
|
||
test_that("hierarchical shrinkage does not work for hard classification", { | ||
rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = FALSE) | ||
expect_error(hshrink(rf, lambda = 5)) | ||
}) | ||
|
||
test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, regression", { | ||
rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = TRUE) | ||
split_values_before <- rf$forest$split.values[[1]] | ||
pred_before <- predict(rf, iris)$predictions | ||
hshrink(rf, lambda = 0) | ||
split_values_after <- rf$forest$split.values[[1]] | ||
pred_after <- predict(rf, iris)$predictions | ||
expect_equal(split_values_before, split_values_after) | ||
expect_equal(pred_before, pred_after) | ||
}) | ||
|
||
test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, probability", { | ||
rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) | ||
class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) | ||
pred_before <- predict(rf, iris)$predictions | ||
hshrink(rf, lambda = 0) | ||
class_freq_after <- simplify2array(rf$forest$terminal.class.counts[[1]]) | ||
pred_after <- predict(rf, iris)$predictions | ||
expect_equal(class_freq_before, class_freq_after) | ||
expect_equal(pred_before, pred_after) | ||
}) | ||
|
||
test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, regression", { | ||
rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, replace = FALSE, sample.fraction = 1, node.stats = TRUE) | ||
split_values_before <- rf$forest$split.values[[1]] | ||
pred_before <- predict(rf, iris)$predictions | ||
split_values_before[1] <- 0 # Modify to create deep copy | ||
hshrink(rf, lambda = 100) | ||
split_values_after <- rf$forest$split.values[[1]] | ||
split_values_after[1] <- 0 # Also modify here | ||
pred_after <- predict(rf, iris)$predictions | ||
expect_false(all(split_values_before == split_values_after)) | ||
|
||
# Shrinkage reduces variance | ||
expect_lt(var(pred_after), var(pred_before)) | ||
|
||
}) | ||
|
||
test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, probability", { | ||
rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) | ||
class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) | ||
pred_before <- predict(rf, iris)$predictions | ||
hshrink(rf, lambda = 100) | ||
class_freq_after <- simplify2array(rf$forest$terminal.class.counts[[1]]) | ||
pred_after <- predict(rf, iris)$predictions | ||
expect_false(all(class_freq_before == class_freq_after)) | ||
|
||
# Shrinkage reduces variance | ||
expect_lt(var(pred_after[, 1]), var(pred_before[, 1])) | ||
expect_lt(var(pred_after[, 2]), var(pred_before[, 2])) | ||
expect_lt(var(pred_after[, 3]), var(pred_before[, 3])) | ||
}) |