Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical shrinkage #692

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.1
Date: 2024-05-15
Date: 2024-05-16
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export(csrf)
export(deforest)
export(getTerminalNodeIDs)
export(holdoutRF)
export(hshrink)
export(importance)
export(importance_pvalues)
export(predictions)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

# ranger 0.16.1
* Set num.threads=2 as default; respect environment variables and options
* Add hierarchical shrinkage
* Allow vector min.node.size and min.bucket for class-specific limits

# ranger 0.16.0
Expand Down
12 changes: 12 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,15 @@ randomObsNode <- function(groups, y, inbag_counts) {
.Call(`_ranger_randomObsNode`, groups, y, inbag_counts)
}

# R-side binding for the compiled routine `_ranger_hshrink_regr`.
# The ten arguments are forwarded unchanged and in order; whatever the native
# call returns is handed back invisibly.
hshrink_regr <- function(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum) {
  native_result <- .Call(
    `_ranger_hshrink_regr`,
    left_children, right_children, num_samples_nodes, node_predictions,
    split_values, lambda, nodeID, parent_n, parent_pred, cum_sum
  )
  invisible(native_result)
}

# R-side binding for the compiled routine `_ranger_hshrink_prob`.
# The nine arguments are forwarded unchanged and in order; whatever the native
# call returns is handed back invisibly.
hshrink_prob <- function(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum) {
  native_result <- .Call(
    `_ranger_hshrink_prob`,
    left_children, right_children, num_samples_nodes, class_freq,
    lambda, nodeID, parent_n, parent_pred, cum_sum
  )
  invisible(native_result)
}

# R-side binding for the compiled routine `_ranger_replace_class_counts`.
# Both arguments are forwarded unchanged; whatever the native call returns is
# handed back invisibly.
replace_class_counts <- function(class_counts_old, class_counts_new) {
  native_result <- .Call(
    `_ranger_replace_class_counts`,
    class_counts_old, class_counts_new
  )
  invisible(native_result)
}

90 changes: 90 additions & 0 deletions R/hshrink.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -------------------------------------------------------------------------------
# This file is part of Ranger.
#
# Ranger is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ranger is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ranger. If not, see <http://www.gnu.org/licenses/>.
#
# Written by:
#
# Marvin N. Wright
# Institut fuer Medizinische Biometrie und Statistik
# Universitaet zu Luebeck
# Ratzeburger Allee 160
# 23562 Luebeck
# Germany
#
# http://www.imbs-luebeck.de
# -------------------------------------------------------------------------------


#' Hierarchical shrinkage
#'
#' Apply hierarchical shrinkage to a ranger object.
#' Hierarchical shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions.
#' For details see Agarwal et al. (2022).
#'
#' Only regression and probability forests are supported. For regression
#' forests the leaf entries of \code{rf$forest$split.values} are overwritten,
#' for probability forests the entries of
#' \code{rf$forest$terminal.class.counts} are replaced.
#'
#' @param rf ranger object, created with \code{node.stats = TRUE}.
#' @param lambda Non-negative shrinkage parameter (a single number).
#'
#' @return The ranger object is modified in-place. The function is called for
#'   this side effect and its value is returned invisibly.
#'
#' @examples
#' rf <- ranger(Sepal.Length ~ ., iris, num.trees = 5, node.stats = TRUE)
#' hshrink(rf, lambda = 5)
#' @references
#' \itemize{
#'   \item Agarwal, A., Tan, Y.S., Ronen, O., Singh, C. & Yu, B. (2022). Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models. Proceedings of the 39th International Conference on Machine Learning, PMLR 162:111-135.
#' }
#' @author Marvin N. Wright
#' @export
hshrink <- function(rf, lambda) {
  if (is.null(rf$forest$num.samples.nodes)) {
    stop("Hierarchical shrinkage needs node statistics, set node.stats=TRUE in ranger() call.")
  }
  # Reject NA, non-numeric or vector-valued lambda up front; otherwise the
  # comparison below fails with an unrelated error or the value reaches the
  # compiled code unchecked.
  if (!is.numeric(lambda) || length(lambda) != 1 || is.na(lambda)) {
    stop("Shrinkage parameter lambda has to be a single non-negative number.")
  }
  if (lambda < 0) {
    stop("Shrinkage parameter lambda has to be non-negative.")
  }

  # seq_len() instead of 1:n so that a zero-length forest is a no-op
  tree_ids <- seq_len(rf$num.trees)

  if (rf$treetype == "Regression") {
    invisible(lapply(tree_ids, function(treeID) {
      # Start the recursion in the root (nodeID 0); the remaining zeros are
      # the initial parent size, parent prediction and cumulative sum.
      hshrink_regr(
        rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]],
        rf$forest$num.samples.nodes[[treeID]], rf$forest$node.predictions[[treeID]],
        rf$forest$split.values[[treeID]], lambda, 0, 0, 0, 0
      )
    }))
  } else if (rf$treetype == "Probability estimation") {
    num_classes <- length(rf$forest$class.values)
    invisible(lapply(tree_ids, function(treeID) {
      # Create temporary class frequency matrix (one row per node)
      class_freq <- t(simplify2array(rf$forest$terminal.class.counts[[treeID]]))

      parent_pred <- rep(0, num_classes)
      cum_sum <- rep(0, num_classes)
      hshrink_prob(
        rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]],
        rf$forest$num.samples.nodes[[treeID]], class_freq,
        lambda, 0, 0, parent_pred, cum_sum
      )

      # Assign temporary matrix values back to ranger object
      replace_class_counts(rf$forest$terminal.class.counts[[treeID]], class_freq)
    }))
  } else if (rf$treetype == "Classification") {
    stop("To apply hierarchical shrinkage to classification forests, use probability=TRUE in the ranger() call.")
  } else if (rf$treetype == "Survival") {
    stop("Hierarchical shrinkage not yet implemented for survival.")
  } else {
    stop("Unknown treetype.")
  }

}


29 changes: 29 additions & 0 deletions man/hshrink.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,62 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// hshrink_regr
// Declaration of the C++ implementation that this wrapper forwards to.
void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions, Rcpp::NumericVector& split_values, double lambda, size_t nodeID, size_t parent_n, double parent_pred, double cum_sum);
// Entry point called from R via .Call: takes the ten arguments as SEXP,
// converts each to the C++ parameter type and invokes hshrink_regr().
RcppExport SEXP _ranger_hshrink_regr(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP node_predictionsSEXP, SEXP split_valuesSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
// Argument conversion, in declaration order
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type node_predictions(node_predictionsSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type split_values(split_valuesSEXP);
Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP);
Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP);
Rcpp::traits::input_parameter< double >::type parent_pred(parent_predSEXP);
Rcpp::traits::input_parameter< double >::type cum_sum(cum_sumSEXP);
hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum);
// The implementation returns void, so R receives NULL
return R_NilValue;
END_RCPP
}
// hshrink_prob
// Declaration of the C++ implementation that this wrapper forwards to.
void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericMatrix& class_freq, double lambda, size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum);
// Entry point called from R via .Call: takes the nine arguments as SEXP,
// converts each to the C++ parameter type and invokes hshrink_prob().
RcppExport SEXP _ranger_hshrink_prob(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP class_freqSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
// Argument conversion, in declaration order
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_freq(class_freqSEXP);
Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP);
Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type parent_pred(parent_predSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type cum_sum(cum_sumSEXP);
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum);
// The implementation returns void, so R receives NULL
return R_NilValue;
END_RCPP
}
// replace_class_counts
// Declaration of the C++ implementation that this wrapper forwards to.
void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new);
// Entry point called from R via .Call: takes the two arguments as SEXP,
// converts each to the C++ parameter type and invokes replace_class_counts().
RcppExport SEXP _ranger_replace_class_counts(SEXP class_counts_oldSEXP, SEXP class_counts_newSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
// Argument conversion, in declaration order
Rcpp::traits::input_parameter< Rcpp::List& >::type class_counts_old(class_counts_oldSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_counts_new(class_counts_newSEXP);
replace_class_counts(class_counts_old, class_counts_new);
// The implementation returns void, so R receives NULL
return R_NilValue;
END_RCPP
}

// Table of routines callable from R via .Call. Each row holds the symbol
// name, the wrapper's function pointer and the number of SEXP arguments the
// wrapper takes (e.g. 10 for _ranger_hshrink_regr, 9 for _ranger_hshrink_prob,
// 2 for _ranger_replace_class_counts). The all-NULL row terminates the table.
static const R_CallMethodDef CallEntries[] = {
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 50},
{"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2},
{"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3},
{"_ranger_hshrink_regr", (DL_FUNC) &_ranger_hshrink_regr, 10},
{"_ranger_hshrink_prob", (DL_FUNC) &_ranger_hshrink_prob, 9},
{"_ranger_replace_class_counts", (DL_FUNC) &_ranger_replace_class_counts, 2},
{NULL, NULL, 0}
};

Expand Down
63 changes: 63 additions & 0 deletions src/utilityRcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,66 @@ Rcpp::NumericMatrix randomObsNode(Rcpp::IntegerMatrix groups, Rcpp::NumericVecto
return result;
}

// Hierarchical shrinkage for one regression tree, walking the tree recursively.
//
// Called on a node, it computes the shrunk value for that node and then either
// stores it (leaf) or passes it on to both children (inner node).
//   left_children / right_children: child node IDs; a left child of 0 marks a leaf
//   num_samples_nodes:              per-node sample counts, handed to children as parent_n
//   node_predictions:               per-node predictions (read only here)
//   split_values:                   receives the shrunk value at leaf positions
//   lambda:                         shrinkage parameter
//   nodeID:                         node to process; 0 is the root
//   parent_n, parent_pred:          sample count and prediction of the parent (ignored in the root)
//   cum_sum:                        shrunk value accumulated along the path so far
//[[Rcpp::export]]
void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children,
                  Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions,
                  Rcpp::NumericVector& split_values, double lambda,
                  size_t nodeID, size_t parent_n, double parent_pred, double cum_sum) {
  const double own_pred = node_predictions[nodeID];

  // Root: start from the plain prediction. Elsewhere: add the shrunk
  // difference to the parent prediction onto the accumulated value.
  const double shrunk = (nodeID == 0)
      ? own_pred
      : cum_sum + (own_pred - parent_pred) / (1 + lambda/parent_n);

  const bool is_leaf = (left_children[nodeID] == 0);
  if (is_leaf) {
    // Leaf predictions are read from split_values, so store the result there
    split_values[nodeID] = shrunk;
    return;
  }

  // Inner node: descend into both subtrees with this node as the parent
  const size_t own_n = num_samples_nodes[nodeID];
  hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values,
               lambda, left_children[nodeID], own_n, own_pred, shrunk);
  hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values,
               lambda, right_children[nodeID], own_n, own_pred, shrunk);
}

// Recursive function for hierarchical shrinkage (probability)
//
// Same scheme as the regression variant, but every node carries a vector of
// class frequencies, stored as row nodeID of class_freq.
//   left_children / right_children: child node IDs; a left child of 0 marks a leaf
//   num_samples_nodes:              per-node sample counts, handed to children as parent_n
//   class_freq:                     node-by-class matrix; leaf rows are overwritten with the result
//   lambda:                         shrinkage parameter
//   nodeID:                         node to process; 0 is the root
//   parent_n, parent_pred:          sample count and class frequencies of the parent (unused in the root)
//   cum_sum:                        shrunk class frequencies accumulated along the path so far
//[[Rcpp::export]]
void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children,
Rcpp::IntegerVector& num_samples_nodes,
Rcpp::NumericMatrix& class_freq, double lambda,
size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum) {

if (nodeID == 0) {
// In the root, just use the prediction
cum_sum = class_freq(nodeID, Rcpp::_);
} else {
// If not root, use shrinkage formula
cum_sum += (class_freq(nodeID, Rcpp::_) - parent_pred) / (1 + lambda/parent_n);
}

if (left_children[nodeID] == 0) {
// If leaf, write the shrunk class frequencies into this node's row of class_freq
class_freq(nodeID, Rcpp::_) = cum_sum;
} else {
// If not leaf, give weighted prediction to child nodes.
// Each child gets its own clone of cum_sum, so the in-place "+=" performed
// in one subtree is not seen by the other.
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda,
left_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum));
hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda,
right_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum));
}
}

// Replace class counts list(vector) with values from matrix
//
// Element i of class_counts_old is overwritten with row i of class_counts_new.
// The list is modified in place; nothing is returned.
// Raises an R error if the matrix has fewer rows than the list has elements,
// instead of reading rows that do not exist.
//[[Rcpp::export]]
void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new) {
  // Signed length type avoids the signed/unsigned comparison in the loop
  const R_xlen_t num_nodes = class_counts_old.size();
  if (class_counts_new.nrow() < num_nodes) {
    Rcpp::stop("Class count matrix has fewer rows than the class count list has elements.");
  }
  for (R_xlen_t i = 0; i < num_nodes; ++i) {
    class_counts_old[i] = class_counts_new(static_cast<int>(i), Rcpp::_);
  }
}

68 changes: 68 additions & 0 deletions tests/testthat/test_hshrink.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
## Tests for hierarchical shrinkage

library(ranger)
context("ranger_hshrink")

## Tests
test_that("hierarchical shrinkage gives an error when node.stats=FALSE", {
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = FALSE)
  # Pin the message so the test cannot pass because of an unrelated error
  expect_error(hshrink(rf, lambda = 5), "node statistics")
})

test_that("hierarchical shrinkage does not work for hard classification", {
  rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = FALSE)
  # Pin the message so the test cannot pass because of an unrelated error
  expect_error(hshrink(rf, lambda = 5), "probability=TRUE")
})

test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, regression", {
  rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = TRUE)
  # hshrink() modifies the split values in place. A plain assignment would
  # refer to the same vector and make the comparison below trivially true,
  # so take a real copy (arithmetic allocates a new vector).
  split_values_before <- rf$forest$split.values[[1]] + 0
  pred_before <- predict(rf, iris)$predictions
  hshrink(rf, lambda = 0)
  split_values_after <- rf$forest$split.values[[1]]
  pred_after <- predict(rf, iris)$predictions
  expect_equal(split_values_before, split_values_after)
  expect_equal(pred_before, pred_after)
})

test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, probability", {
  # Class counts of the first tree, collected into one array
  first_tree_counts <- function(fit) {
    simplify2array(fit$forest$terminal.class.counts[[1]])
  }

  fit <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE)

  # State before shrinkage
  counts_pre <- first_tree_counts(fit)
  probs_pre <- predict(fit, iris)$predictions

  hshrink(fit, lambda = 0)

  # State after shrinkage
  counts_post <- first_tree_counts(fit)
  probs_post <- predict(fit, iris)$predictions

  # No shrinkage: nothing may have moved
  expect_equal(counts_pre, counts_post)
  expect_equal(probs_pre, probs_post)
})

test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, regression", {
  fit <- ranger(Sepal.Length ~ ., iris, num.trees = 1, replace = FALSE, sample.fraction = 1, node.stats = TRUE)

  # The write to element 1 forces a private copy, so the in-place update done
  # by hshrink() cannot leak into the "before" vector.
  sv_old <- fit$forest$split.values[[1]]
  sv_old[1] <- 0
  yhat_old <- predict(fit, iris)$predictions

  hshrink(fit, lambda = 100)

  # Same write on the "after" vector keeps element 1 comparable
  sv_new <- fit$forest$split.values[[1]]
  sv_new[1] <- 0
  yhat_new <- predict(fit, iris)$predictions

  expect_false(all(sv_old == sv_new))

  # Shrinkage pulls predictions together, so their variance has to drop
  expect_lt(var(yhat_new), var(yhat_old))
})

test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, probability", {
  # Class counts of the first tree, collected into one array
  first_tree_counts <- function(fit) {
    simplify2array(fit$forest$terminal.class.counts[[1]])
  }

  fit <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE)
  counts_pre <- first_tree_counts(fit)
  probs_pre <- predict(fit, iris)$predictions

  hshrink(fit, lambda = 100)

  counts_post <- first_tree_counts(fit)
  probs_post <- predict(fit, iris)$predictions
  expect_false(all(counts_pre == counts_post))

  # Shrinkage pulls predictions together: variance drops for each of the three classes
  for (class_col in seq_len(3)) {
    expect_lt(var(probs_post[, class_col]), var(probs_pre[, class_col]))
  }
})