Skip to content

Commit

Permalink
Merge pull request #692 from imbs-hl/hierarchical_shrinkage
Browse files Browse the repository at this point in the history
Hierarchical shrinkage
  • Loading branch information
mnwright authored May 16, 2024
2 parents f570f7a + 7da74a8 commit e3047c6
Show file tree
Hide file tree
Showing 9 changed files with 317 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.1
Date: 2024-05-15
Date: 2024-05-16
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export(csrf)
export(deforest)
export(getTerminalNodeIDs)
export(holdoutRF)
export(hshrink)
export(importance)
export(importance_pvalues)
export(predictions)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

# ranger 0.16.1
* Set num.threads=2 as default; respect environment variables and options
* Add hierarchical shrinkage
* Allow vector min.node.size and min.bucket for class-specific limits

# ranger 0.16.0
Expand Down
12 changes: 12 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,15 @@ randomObsNode <- function(groups, y, inbag_counts) {
.Call(`_ranger_randomObsNode`, groups, y, inbag_counts)
}

# Thin R-side wrapper: forwards all ten arguments, unchanged and in the same
# order, to the compiled routine `_ranger_hshrink_regr` via .Call() and returns
# the result invisibly.
# NOTE(review): this lives in R/RcppExports.R, which looks Rcpp-generated --
# hand-written comments here are presumably lost on regeneration; confirm.
hshrink_regr <- function(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum) {
invisible(.Call(`_ranger_hshrink_regr`, left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum))
}

# Thin R-side wrapper: forwards all nine arguments, unchanged and in the same
# order, to the compiled routine `_ranger_hshrink_prob` via .Call() and returns
# the result invisibly.
# NOTE(review): this lives in R/RcppExports.R, which looks Rcpp-generated --
# hand-written comments here are presumably lost on regeneration; confirm.
hshrink_prob <- function(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum) {
invisible(.Call(`_ranger_hshrink_prob`, left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum))
}

# Thin R-side wrapper: forwards both arguments to the compiled routine
# `_ranger_replace_class_counts` via .Call() and returns the result invisibly.
# NOTE(review): this lives in R/RcppExports.R, which looks Rcpp-generated --
# hand-written comments here are presumably lost on regeneration; confirm.
replace_class_counts <- function(class_counts_old, class_counts_new) {
invisible(.Call(`_ranger_replace_class_counts`, class_counts_old, class_counts_new))
}

90 changes: 90 additions & 0 deletions R/hshrink.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -------------------------------------------------------------------------------
# This file is part of Ranger.
#
# Ranger is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ranger is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ranger. If not, see <http://www.gnu.org/licenses/>.
#
# Written by:
#
# Marvin N. Wright
# Institut fuer Medizinische Biometrie und Statistik
# Universitaet zu Luebeck
# Ratzeburger Allee 160
# 23562 Luebeck
# Germany
#
# http://www.imbs-luebeck.de
# -------------------------------------------------------------------------------


#' Hierarchical shrinkage
#'
#' Apply hierarchical shrinkage to a ranger object.
#' Hierarchical shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions.
#' For details see Agarwal et al. (2022).
#'
#' Supported are regression forests and probability forests. Classification
#' forests have to be grown with \code{probability = TRUE}; survival forests
#' are not supported.
#'
#' @param rf ranger object, created with \code{node.stats = TRUE}.
#' @param lambda Non-negative shrinkage parameter (a single numeric value).
#'
#' @return The ranger object is modified in-place. The function is called for
#'   this side effect; its (invisible) return value carries no information.
#'
#' @examples
#' ## Hierarchical shrinkage for a probability forest
#' rf <- ranger(Species ~ ., iris, node.stats = TRUE, probability = TRUE)
#' hshrink(rf, lambda = 5)
#'
#' @references
#' \itemize{
#'   \item Agarwal, A., Tan, Y.S., Ronen, O., Singh, C. & Yu, B. (2022). Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models. Proceedings of the 39th International Conference on Machine Learning, PMLR 162:111-135.
#' }
#' @author Marvin N. Wright
#' @export
hshrink <- function(rf, lambda) {
  if (is.null(rf$forest$num.samples.nodes)) {
    stop("Hierarchical shrinkage needs node statistics, set node.stats=TRUE in ranger() call.")
  }
  # Reject NA, vectors and non-numeric input up front; otherwise the failure
  # would surface as a cryptic error in if() or inside the compiled code
  if (!is.numeric(lambda) || length(lambda) != 1L || is.na(lambda)) {
    stop("Shrinkage parameter lambda has to be a single numeric value.")
  }
  if (lambda < 0) {
    stop("Shrinkage parameter lambda has to be non-negative.")
  }

  # identical() instead of `==` so that a missing treetype falls through to
  # the "Unknown treetype." error instead of failing inside if()
  treetype <- rf$treetype

  # seq_len() instead of 1:n so that a forest without trees is a no-op
  tree_ids <- seq_len(rf$num.trees)

  if (identical(treetype, "Regression")) {
    invisible(lapply(tree_ids, function(treeID) {
      # Recursion starts at the root (nodeID 0); the remaining start values
      # (parent_n, parent_pred, cum_sum) are placeholders that the root
      # overwrites or ignores. Leaf values in split.values are changed in-place.
      hshrink_regr(
        rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]],
        rf$forest$num.samples.nodes[[treeID]], rf$forest$node.predictions[[treeID]],
        rf$forest$split.values[[treeID]], lambda, 0, 0, 0, 0
      )
    }))
  } else if (identical(treetype, "Probability estimation")) {
    invisible(lapply(tree_ids, function(treeID) {
      # Create temporary class frequency matrix (one row per node)
      class_freq <- t(simplify2array(rf$forest$terminal.class.counts[[treeID]]))

      parent_pred <- rep(0, length(rf$forest$class.values))
      cum_sum <- rep(0, length(rf$forest$class.values))
      hshrink_prob(
        rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]],
        rf$forest$num.samples.nodes[[treeID]], class_freq,
        lambda, 0, 0, parent_pred, cum_sum
      )

      # Assign temporary matrix values back to ranger object
      replace_class_counts(rf$forest$terminal.class.counts[[treeID]], class_freq)
    }))
  } else if (identical(treetype, "Classification")) {
    stop("To apply hierarchical shrinkage to classification forests, use probability=TRUE in the ranger() call.")
  } else if (identical(treetype, "Survival")) {
    stop("Hierarchical shrinkage not yet implemented for survival.")
  } else {
    stop("Unknown treetype.")
  }

}


29 changes: 29 additions & 0 deletions man/hshrink.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,62 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// hshrink_regr
// .Call() entry point for hshrink_regr(): each of the ten SEXP arguments is
// converted to the corresponding C++ parameter type, the function is called,
// and R_NilValue is returned because hshrink_regr() is void.
// NOTE(review): src/RcppExports.cpp looks Rcpp-generated -- hand-written
// comments here are presumably lost on regeneration; confirm.
void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions, Rcpp::NumericVector& split_values, double lambda, size_t nodeID, size_t parent_n, double parent_pred, double cum_sum);
RcppExport SEXP _ranger_hshrink_regr(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP node_predictionsSEXP, SEXP split_valuesSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type node_predictions(node_predictionsSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type split_values(split_valuesSEXP);
Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP);
Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP);
Rcpp::traits::input_parameter< double >::type parent_pred(parent_predSEXP);
Rcpp::traits::input_parameter< double >::type cum_sum(cum_sumSEXP);
hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum);
return R_NilValue;
END_RCPP
}
// hshrink_prob
// .Call() entry point for hshrink_prob(): each of the nine SEXP arguments is
// converted to the corresponding C++ parameter type, the function is called,
// and R_NilValue is returned because hshrink_prob() is void.
// NOTE(review): src/RcppExports.cpp looks Rcpp-generated -- hand-written
// comments here are presumably lost on regeneration; confirm.
void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericMatrix& class_freq, double lambda, size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum);
RcppExport SEXP _ranger_hshrink_prob(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP class_freqSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_freq(class_freqSEXP);
Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP);
Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type parent_pred(parent_predSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type cum_sum(cum_sumSEXP);
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum);
return R_NilValue;
END_RCPP
}
// replace_class_counts
// .Call() entry point for replace_class_counts(): both SEXP arguments are
// converted to the corresponding C++ parameter types, the function is called,
// and R_NilValue is returned because replace_class_counts() is void.
// NOTE(review): src/RcppExports.cpp looks Rcpp-generated -- hand-written
// comments here are presumably lost on regeneration; confirm.
void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new);
RcppExport SEXP _ranger_replace_class_counts(SEXP class_counts_oldSEXP, SEXP class_counts_newSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::List& >::type class_counts_old(class_counts_oldSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_counts_new(class_counts_newSEXP);
replace_class_counts(class_counts_old, class_counts_new);
return R_NilValue;
END_RCPP
}

// Registration table of the .Call() entry points: {name, function pointer,
// number of SEXP arguments}. The counts for the three shrinkage routines
// (10, 9, 2) match the parameter lists of their entry points; the table is
// terminated by the {NULL, NULL, 0} sentinel.
static const R_CallMethodDef CallEntries[] = {
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 50},
{"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2},
{"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3},
{"_ranger_hshrink_regr", (DL_FUNC) &_ranger_hshrink_regr, 10},
{"_ranger_hshrink_prob", (DL_FUNC) &_ranger_hshrink_prob, 9},
{"_ranger_replace_class_counts", (DL_FUNC) &_ranger_replace_class_counts, 2},
{NULL, NULL, 0}
};

Expand Down
63 changes: 63 additions & 0 deletions src/utilityRcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,66 @@ Rcpp::NumericMatrix randomObsNode(Rcpp::IntegerMatrix groups, Rcpp::NumericVecto
return result;
}

// Hierarchical shrinkage for a single regression tree, implemented as a
// depth-first recursion that starts at the root (nodeID 0).
//
// Along each root-to-leaf path a running sum is built: it starts with the root
// prediction and, at every further node, grows by the difference between the
// node's and its parent's prediction, divided by (1 + lambda / parent_n).
// When a leaf is reached, the running sum replaces the leaf's entry in
// split_values (which is what prediction reads for leaves).
//
// left_children, right_children: child node IDs; a left child ID of 0 marks a leaf
// num_samples_nodes:             number of samples per node
// node_predictions:              prediction per node (read only)
// split_values:                  modified in-place at the leaf positions
// lambda:                        shrinkage parameter
// nodeID:                        node handled by this call
// parent_n, parent_pred:         sample count and prediction of the parent
//                                (not used for the root)
// cum_sum:                       running sum handed down from the parent
//[[Rcpp::export]]
void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children,
                  Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions,
                  Rcpp::NumericVector& split_values, double lambda,
                  size_t nodeID, size_t parent_n, double parent_pred, double cum_sum) {
  const double own_pred = node_predictions[nodeID];

  if (nodeID != 0) {
    // Below the root: add the shrunken step from parent to this node
    cum_sum += (own_pred - parent_pred) / (1 + lambda/parent_n);
  } else {
    // Root: the running sum starts with the plain prediction
    cum_sum = own_pred;
  }

  // Leaf: store the shrunken prediction and stop descending
  if (left_children[nodeID] == 0) {
    split_values[nodeID] = cum_sum;
    return;
  }

  // Inner node: descend left, then right; this node acts as parent for both
  const size_t own_n = num_samples_nodes[nodeID];
  const int child_ids[2] = {left_children[nodeID], right_children[nodeID]};
  for (int k = 0; k < 2; ++k) {
    hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values,
                 lambda, child_ids[k], own_n, own_pred, cum_sum);
  }
}

// Recursive function for hierarchical shrinkage (probability)
//
// Works like the regression variant, but on whole rows of class_freq (indexed
// by node ID) instead of scalar predictions. Along each root-to-leaf path a
// running sum is built: it starts with the root's row and, at every further
// node, grows by (row of node - row of parent) / (1 + lambda / parent_n).
// At a leaf the running sum replaces the leaf's row in class_freq; rows of
// inner nodes are only read.
//
// left_children, right_children: child node IDs; a left child ID of 0 marks a leaf
// num_samples_nodes:             number of samples per node
// class_freq:                    modified in-place at the leaf rows
// lambda:                        shrinkage parameter
// nodeID:                        node handled by this call
// parent_n, parent_pred:         sample count and class_freq row of the parent
//                                (not used for the root)
// cum_sum:                       running sum handed down from the parent
//[[Rcpp::export]]
void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children,
Rcpp::IntegerVector& num_samples_nodes,
Rcpp::NumericMatrix& class_freq, double lambda,
size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum) {

if (nodeID == 0) {
// In the root, just use the prediction
cum_sum = class_freq(nodeID, Rcpp::_);
} else {
// If not root, use shrinkage formula
cum_sum += (class_freq(nodeID, Rcpp::_) - parent_pred) / (1 + lambda/parent_n);
}

if (left_children[nodeID] == 0) {
// If leaf, overwrite this node's row in class_freq with the shrunken values
class_freq(nodeID, Rcpp::_) = cum_sum;
} else {
// If not leaf, give weighted prediction to child nodes.
// clone() hands each subtree its own copy of the running sum, so the in-place
// += in one child cannot leak into its sibling.
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda,
left_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum));
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda,
right_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum));
}
}

// Write the rows of a class count matrix back into a list of per-node vectors:
// element i of class_counts_old is replaced by row i of class_counts_new.
// The list determines how many rows are copied; it is modified in-place.
//[[Rcpp::export]]
void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new) {
  const size_t num_nodes = class_counts_old.size();
  for (size_t node = 0; node < num_nodes; ++node) {
    class_counts_old[node] = class_counts_new(node, Rcpp::_);
  }
}

68 changes: 68 additions & 0 deletions tests/testthat/test_hshrink.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
## Tests for hierarchical shrinkage

library(ranger)
context("ranger_hshrink")

## Tests
test_that("hierarchical shrinkage gives an error when node.stats=FALSE", {
  # Forest grown without node statistics
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = FALSE)

  # Match the dedicated message so that an unrelated error cannot make this
  # test pass by accident
  expect_error(hshrink(rf, lambda = 5), "node.stats=TRUE", fixed = TRUE)
})

test_that("hierarchical shrinkage does not work for hard classification", {
  # Hard classification forest (no class probabilities), with node statistics
  classif_forest <- ranger(Species ~ ., data = iris, num.trees = 1,
                           node.stats = TRUE, probability = FALSE)

  expect_error(hshrink(classif_forest, lambda = 5))
})

test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, regression", {
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = TRUE)

  # hshrink() modifies the forest in-place, so a plain assignment may keep
  # pointing at the very vector that gets overwritten and the comparison below
  # would compare the vector with itself. `+ 0` forces an independent copy.
  split_values_before <- rf$forest$split.values[[1]] + 0
  pred_before <- predict(rf, iris)$predictions

  hshrink(rf, lambda = 0)

  split_values_after <- rf$forest$split.values[[1]]
  pred_after <- predict(rf, iris)$predictions

  expect_equal(split_values_before, split_values_after)
  expect_equal(pred_before, pred_after)
})

test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, probability", {
  forest <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE)

  # Class frequencies of the first tree and predictions on the training data
  snapshot <- function() {
    list(
      freq = simplify2array(forest$forest$terminal.class.counts[[1]]),
      pred = predict(forest, iris)$predictions
    )
  }

  before <- snapshot()
  hshrink(forest, lambda = 0)
  after <- snapshot()

  # Without shrinkage nothing may change
  expect_equal(before$freq, after$freq)
  expect_equal(before$pred, after$pred)
})

test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, regression", {
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, replace = FALSE,
               sample.fraction = 1, node.stats = TRUE)

  # State before shrinkage; writing to the vector creates a deep copy that is
  # not touched by the in-place modification
  leaf_values_old <- rf$forest$split.values[[1]]
  leaf_values_old[1] <- 0
  preds_old <- predict(rf, iris)$predictions

  hshrink(rf, lambda = 100)

  # State after shrinkage, first element set to 0 as above
  leaf_values_new <- rf$forest$split.values[[1]]
  leaf_values_new[1] <- 0
  preds_new <- predict(rf, iris)$predictions

  # At least one value has to differ
  expect_true(any(leaf_values_old != leaf_values_new))

  # Shrinkage reduces variance
  expect_lt(var(preds_new), var(preds_old))
})

test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, probability", {
  rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE)

  # State before shrinkage
  freq_old <- simplify2array(rf$forest$terminal.class.counts[[1]])
  preds_old <- predict(rf, iris)$predictions

  hshrink(rf, lambda = 100)

  # State after shrinkage
  freq_new <- simplify2array(rf$forest$terminal.class.counts[[1]])
  preds_new <- predict(rf, iris)$predictions

  # At least one class frequency has to differ
  expect_true(any(freq_old != freq_new))

  # Shrinkage reduces variance, checked separately for each of the 3 classes
  for (class_idx in seq_len(3)) {
    expect_lt(var(preds_new[, class_idx]), var(preds_old[, class_idx]))
  }
})

0 comments on commit e3047c6

Please sign in to comment.