From 53f9c924c9d543297c6ac2693bc7a23886b76307 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Fri, 15 Sep 2023 14:34:03 +0200 Subject: [PATCH 1/2] add horizontal shrinkage --- DESCRIPTION | 2 +- NAMESPACE | 1 + NEWS.md | 1 + R/RcppExports.R | 12 +++++ R/hshrink.R | 90 +++++++++++++++++++++++++++++++++++ man/hshrink.Rd | 29 +++++++++++ src/RcppExports.cpp | 51 ++++++++++++++++++++ src/utilityRcpp.cpp | 63 ++++++++++++++++++++++++ tests/testthat/test_hshrink.R | 68 ++++++++++++++++++++++++++ 9 files changed, 316 insertions(+), 1 deletion(-) create mode 100644 R/hshrink.R create mode 100644 man/hshrink.Rd create mode 100644 tests/testthat/test_hshrink.R diff --git a/DESCRIPTION b/DESCRIPTION index 20ef02c46..dc27d2a2a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,7 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests Version: 0.15.4 -Date: 2023-09-12 +Date: 2023-09-15 Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] Maintainer: Marvin N. Wright Description: A fast implementation of Random Forests, particularly suited for high diff --git a/NAMESPACE b/NAMESPACE index 3aad2ad21..64d06c481 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,6 +17,7 @@ export(csrf) export(deforest) export(getTerminalNodeIDs) export(holdoutRF) +export(hshrink) export(importance) export(importance_pvalues) export(predictions) diff --git a/NEWS.md b/NEWS.md index 0795b254e..402c6f71b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,7 @@ # ranger 0.15.4 * Add node.stats option to save node statistics of all nodes +* Add horizontal shrinkage # ranger 0.15.3 * Fix min bucket option in C++ version diff --git a/R/RcppExports.R b/R/RcppExports.R index 1de8f9d4c..3d8790001 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -13,3 +13,15 @@ randomObsNode <- function(groups, y, inbag_counts) { .Call(`_ranger_randomObsNode`, groups, y, inbag_counts) } +hshrink_regr <- function(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum) { + invisible(.Call(`_ranger_hshrink_regr`, left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum)) +} + +hshrink_prob <- function(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum) { + invisible(.Call(`_ranger_hshrink_prob`, left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum)) +} + +replace_class_counts <- function(class_counts_old, class_counts_new) { + invisible(.Call(`_ranger_replace_class_counts`, class_counts_old, class_counts_new)) +} + diff --git a/R/hshrink.R b/R/hshrink.R new file mode 100644 index 000000000..02c46f701 --- /dev/null +++ b/R/hshrink.R @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------------- +# This file is part of Ranger. +# +# Ranger is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ranger is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ranger. If not, see . +# +# Written by: +# +# Marvin N. Wright +# Institut fuer Medizinische Biometrie und Statistik +# Universitaet zu Luebeck +# Ratzeburger Allee 160 +# 23562 Luebeck +# Germany +# +# http://www.imbs-luebeck.de +# ------------------------------------------------------------------------------- + + +#' Horizontal shrinkage +#' +#' Apply horizontal shrinkage to a ranger object. +#' Horizontal shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. +#' For details see Agarwal et al. (2022). +#' +#' @param rf ranger object, created with \code{node.stats = TRUE}. +#' @param lambda Non-negative shrinkage parameter. +#' +#' @return The ranger object is modified in-place. +#' +#' @examples +##' @references +##' \itemize{ +##' \item Agarwal, A., Tan, Y.S., Ronen, O., Singh, C. & Yu, B. (2022). Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models. Proceedings of the 39th International Conference on Machine Learning, PMLR 162:111-135. +##' } +#' @author Marvin N. Wright +#' @export +hshrink <- function(rf, lambda) { + if (is.null(rf$forest$num.samples.nodes)) { + stop("Horizontal shrinkage needs node statistics, set node.stats=TRUE in ranger() call.") + } + if (lambda < 0) { + stop("Shrinkage parameter lambda has to be non-negative.") + } + + if (rf$treetype == "Regression") { + invisible(lapply(1:rf$num.trees, function(treeID) { + hshrink_regr( + rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]], + rf$forest$num.samples.nodes[[treeID]], rf$forest$node.predictions[[treeID]], + rf$forest$split.values[[treeID]], lambda, 0, 0, 0, 0 + ) + })) + } else if (rf$treetype == "Probability estimation") { + invisible(lapply(1:rf$num.trees, function(treeID) { + # Create temporary class frequency matrix + class_freq <- t(simplify2array(rf$forest$terminal.class.counts[[treeID]])) + + parent_pred <- rep(0, length(rf$forest$class.values)) + cum_sum <- rep(0, length(rf$forest$class.values)) + hshrink_prob( + rf$forest$child.nodeIDs[[treeID]][[1]], rf$forest$child.nodeIDs[[treeID]][[2]], + rf$forest$num.samples.nodes[[treeID]], class_freq, + lambda, 0, 0, parent_pred, cum_sum + ) + + # Assign temporary matrix values back to ranger object + replace_class_counts(rf$forest$terminal.class.counts[[treeID]], class_freq) + })) + } else if (rf$treetype == "Classification") { + stop("To apply horizontal shrinkage to classification forests, use probability=TRUE in the ranger() call.") + } else if (rf$treetype == "Survival") { + stop("Horizontal shrinkage not yet implemented for survival.") + } else { + stop("Unknown treetype.") + } + +} + + diff --git a/man/hshrink.Rd b/man/hshrink.Rd new file mode 100644 index 000000000..df57f1bc7 --- /dev/null +++ b/man/hshrink.Rd @@ -0,0 +1,29 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/hshrink.R +\name{hshrink} +\alias{hshrink} +\title{Horizontal shrinkage} +\usage{ +hshrink(rf, lambda) +} +\arguments{ +\item{rf}{ranger object, created with \code{node.stats = TRUE}.} + +\item{lambda}{Non-negative shrinkage parameter.} +} +\value{ +The ranger object is modified in-place. +} +\description{ +Apply horizontal shrinkage to a ranger object. +Horizontal shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. +For details see Agarwal et al. (2022). +} +\references{ +\itemize{ + \item Agarwal, A., Tan, Y.S., Ronen, O., Singh, C. & Yu, B. (2022). Hierarchical Shrinkage: Improving the accuracy and interpretability of tree-based models. Proceedings of the 39th International Conference on Machine Learning, PMLR 162:111-135. + } +} +\author{ +Marvin N. Wright +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4cc76e598..ead9a011b 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -95,11 +95,62 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// hshrink_regr +void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions, Rcpp::NumericVector& split_values, double lambda, size_t nodeID, size_t parent_n, double parent_pred, double cum_sum); +RcppExport SEXP _ranger_hshrink_regr(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP node_predictionsSEXP, SEXP split_valuesSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP); + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP); + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type node_predictions(node_predictionsSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericVector& >::type split_values(split_valuesSEXP); + Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP); + Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP); + Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP); + Rcpp::traits::input_parameter< double >::type parent_pred(parent_predSEXP); + Rcpp::traits::input_parameter< double >::type cum_sum(cum_sumSEXP); + hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values, lambda, nodeID, parent_n, parent_pred, cum_sum); + return R_NilValue; +END_RCPP +} +// hshrink_prob +void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericMatrix& class_freq, double lambda, size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum); +RcppExport SEXP _ranger_hshrink_prob(SEXP left_childrenSEXP, SEXP right_childrenSEXP, SEXP num_samples_nodesSEXP, SEXP class_freqSEXP, SEXP lambdaSEXP, SEXP nodeIDSEXP, SEXP parent_nSEXP, SEXP parent_predSEXP, SEXP cum_sumSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type left_children(left_childrenSEXP); + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type right_children(right_childrenSEXP); + Rcpp::traits::input_parameter< Rcpp::IntegerVector& >::type num_samples_nodes(num_samples_nodesSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_freq(class_freqSEXP); + Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP); + Rcpp::traits::input_parameter< size_t >::type nodeID(nodeIDSEXP); + Rcpp::traits::input_parameter< size_t >::type parent_n(parent_nSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericVector >::type parent_pred(parent_predSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericVector >::type cum_sum(cum_sumSEXP); + hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda, nodeID, parent_n, parent_pred, cum_sum); + return R_NilValue; +END_RCPP +} +// replace_class_counts +void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new); +RcppExport SEXP _ranger_replace_class_counts(SEXP class_counts_oldSEXP, SEXP class_counts_newSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< Rcpp::List& >::type class_counts_old(class_counts_oldSEXP); + Rcpp::traits::input_parameter< Rcpp::NumericMatrix& >::type class_counts_new(class_counts_newSEXP); + replace_class_counts(class_counts_old, class_counts_new); + return R_NilValue; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 48}, {"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2}, {"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3}, + {"_ranger_hshrink_regr", (DL_FUNC) &_ranger_hshrink_regr, 10}, + {"_ranger_hshrink_prob", (DL_FUNC) &_ranger_hshrink_prob, 9}, + {"_ranger_replace_class_counts", (DL_FUNC) &_ranger_replace_class_counts, 2}, {NULL, NULL, 0} }; diff --git a/src/utilityRcpp.cpp b/src/utilityRcpp.cpp index dbad4a683..9866702df 100644 --- a/src/utilityRcpp.cpp +++ b/src/utilityRcpp.cpp @@ -89,3 +89,66 @@ Rcpp::NumericMatrix randomObsNode(Rcpp::IntegerMatrix groups, Rcpp::NumericVecto return result; } +// Recursive function for horizontal shrinkage (regression) +//[[Rcpp::export]] +void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, + Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions, + Rcpp::NumericVector& split_values, double lambda, + size_t nodeID, size_t parent_n, double parent_pred, double cum_sum) { + if (nodeID == 0) { + // In the root, just use the prediction + cum_sum = node_predictions[nodeID]; + } else { + // If not root, use shrinkage formula + cum_sum += (node_predictions[nodeID] - parent_pred) / (1 + lambda/parent_n); + } + + if (left_children[nodeID] == 0) { + // If leaf, change node prediction in split_values (used for prediction) + split_values[nodeID] = cum_sum; + } else { + // If not leaf, give weighted prediction to child nodes + hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values, + lambda, left_children[nodeID], num_samples_nodes[nodeID], node_predictions[nodeID], + cum_sum); + hshrink_regr(left_children, right_children, num_samples_nodes, node_predictions, split_values, + lambda, right_children[nodeID], num_samples_nodes[nodeID], node_predictions[nodeID], + cum_sum); + } +} + +// Recursive function for horizontal shrinkage (probability) +//[[Rcpp::export]] +void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, + Rcpp::IntegerVector& num_samples_nodes, + Rcpp::NumericMatrix& class_freq, double lambda, + size_t nodeID, size_t parent_n, Rcpp::NumericVector parent_pred, Rcpp::NumericVector cum_sum) { + + if (nodeID == 0) { + // In the root, just use the prediction + cum_sum = class_freq(nodeID, Rcpp::_); + } else { + // If not root, use shrinkage formula + cum_sum += (class_freq(nodeID, Rcpp::_) - parent_pred) / (1 + lambda/parent_n); + } + + if (left_children[nodeID] == 0) { + // If leaf, change node prediction in split_values (used for prediction) + class_freq(nodeID, Rcpp::_) = cum_sum; + } else { + // If not leaf, give weighted prediction to child nodes + hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda, + left_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum)); + hshrink_prob(left_children, right_children, num_samples_nodes, class_freq, lambda, + right_children[nodeID], num_samples_nodes[nodeID], class_freq(nodeID, Rcpp::_), clone(cum_sum)); + } +} + +// Replace class counts list(vector) with values from matrix +//[[Rcpp::export]] +void replace_class_counts(Rcpp::List& class_counts_old, Rcpp::NumericMatrix& class_counts_new) { + for (size_t i = 0; i < class_counts_old.size(); ++i) { + class_counts_old[i] = class_counts_new(i, Rcpp::_); + } +} + diff --git a/tests/testthat/test_hshrink.R b/tests/testthat/test_hshrink.R new file mode 100644 index 000000000..e6a43fcaa --- /dev/null +++ b/tests/testthat/test_hshrink.R @@ -0,0 +1,68 @@ +## Tests for hierarchical shrinkage + +library(ranger) +context("ranger_hshrink") + +## Tests +test_that("horizontal shrinkage gives an error when node.stats=FALSE", { + rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = FALSE) + expect_error(hshrink(rf, lambda = 5)) +}) + +test_that("horizontal shrinkage does not work for hard classification", { + rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = FALSE) + expect_error(hshrink(rf, lambda = 5)) +}) + +test_that("horizontal shrinkage with lambda=0 doesn't change leafs and prediction, regression", { + rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = TRUE) + split_values_before <- rf$forest$split.values[[1]] + pred_before <- predict(rf, iris)$predictions + hshrink(rf, lambda = 0) + split_values_after <- rf$forest$split.values[[1]] + pred_after <- predict(rf, iris)$predictions + expect_equal(split_values_before, split_values_after) + expect_equal(pred_before, pred_after) +}) + +test_that("horizontal shrinkage with lambda=0 doesn't change leafs and prediction, probability", { + rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) + class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) + pred_before <- predict(rf, iris)$predictions + hshrink(rf, lambda = 0) + class_freq_after <- simplify2array(rf$forest$terminal.class.counts[[1]]) + pred_after <- predict(rf, iris)$predictions + expect_equal(class_freq_before, class_freq_after) + expect_equal(pred_before, pred_after) +}) + +test_that("horizontal shrinkage with lambda>0 does change leafs and prediction, regression", { + rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, replace = FALSE, sample.fraction = 1, node.stats = TRUE) + split_values_before <- rf$forest$split.values[[1]] + pred_before <- predict(rf, iris)$predictions + split_values_before[1] <- 0 # Modify to create deep copy + hshrink(rf, lambda = 100) + split_values_after <- rf$forest$split.values[[1]] + split_values_after[1] <- 0 # Also modify here + pred_after <- predict(rf, iris)$predictions + expect_false(all(split_values_before == split_values_after)) + + # Shrinkage reduces variance + expect_lt(var(pred_after), var(pred_before)) + +}) + +test_that("horizontal shrinkage with lambda>0 does change leafs and prediction, probability", { + rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) + class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) + pred_before <- predict(rf, iris)$predictions + hshrink(rf, lambda = 100) + class_freq_after <- simplify2array(rf$forest$terminal.class.counts[[1]]) + pred_after <- predict(rf, iris)$predictions + expect_false(all(class_freq_before == class_freq_after)) + + # Shrinkage reduces variance + expect_lt(var(pred_after[, 1]), var(pred_before[, 1])) + expect_lt(var(pred_after[, 2]), var(pred_before[, 2])) + expect_lt(var(pred_after[, 3]), var(pred_before[, 3])) +}) From 22967926cc3ed90418fff3f021893a8ed989fe4a Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Fri, 15 Sep 2023 14:37:45 +0200 Subject: [PATCH 2/2] hierarchical not horizontal --- NEWS.md | 2 +- R/hshrink.R | 12 ++++++------ man/hshrink.Rd | 6 +++--- src/utilityRcpp.cpp | 4 ++-- tests/testthat/test_hshrink.R | 12 ++++++------ 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/NEWS.md b/NEWS.md index 402c6f71b..d18c08f71 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,7 @@ # ranger 0.15.4 * Add node.stats option to save node statistics of all nodes -* Add horizontal shrinkage +* Add hierarchical shrinkage # ranger 0.15.3 * Fix min bucket option in C++ version diff --git a/R/hshrink.R b/R/hshrink.R index 02c46f701..001866f4f 100644 --- a/R/hshrink.R +++ b/R/hshrink.R @@ -27,10 +27,10 @@ # ------------------------------------------------------------------------------- -#' Horizontal shrinkage +#' Hierarchical shrinkage #' -#' Apply horizontal shrinkage to a ranger object. -#' Horizontal shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. +#' Apply hierarchical shrinkage to a ranger object. +#' Hierarchical shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. #' For details see Agarwal et al. (2022). #' #' @param rf ranger object, created with \code{node.stats = TRUE}. @@ -47,7 +47,7 @@ #' @export hshrink <- function(rf, lambda) { if (is.null(rf$forest$num.samples.nodes)) { - stop("Horizontal shrinkage needs node statistics, set node.stats=TRUE in ranger() call.") + stop("Hierarchical shrinkage needs node statistics, set node.stats=TRUE in ranger() call.") } if (lambda < 0) { stop("Shrinkage parameter lambda has to be non-negative.") @@ -78,9 +78,9 @@ hshrink <- function(rf, lambda) { replace_class_counts(rf$forest$terminal.class.counts[[treeID]], class_freq) })) } else if (rf$treetype == "Classification") { - stop("To apply horizontal shrinkage to classification forests, use probability=TRUE in the ranger() call.") + stop("To apply hierarchical shrinkage to classification forests, use probability=TRUE in the ranger() call.") } else if (rf$treetype == "Survival") { - stop("Horizontal shrinkage not yet implemented for survival.") + stop("Hierarchical shrinkage not yet implemented for survival.") } else { stop("Unknown treetype.") } diff --git a/man/hshrink.Rd b/man/hshrink.Rd index df57f1bc7..e48c9a2e9 100644 --- a/man/hshrink.Rd +++ b/man/hshrink.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/hshrink.R \name{hshrink} \alias{hshrink} -\title{Horizontal shrinkage} +\title{Hierarchical shrinkage} \usage{ hshrink(rf, lambda) } @@ -15,8 +15,8 @@ hshrink(rf, lambda) The ranger object is modified in-place. } \description{ -Apply horizontal shrinkage to a ranger object. -Horizontal shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. +Apply hierarchical shrinkage to a ranger object. +Hierarchical shrinkage is a regularization technique that recursively shrinks node predictions towards parent node predictions. For details see Agarwal et al. (2022). } \references{ diff --git a/src/utilityRcpp.cpp b/src/utilityRcpp.cpp index 9866702df..799d57d8d 100644 --- a/src/utilityRcpp.cpp +++ b/src/utilityRcpp.cpp @@ -89,7 +89,7 @@ Rcpp::NumericMatrix randomObsNode(Rcpp::IntegerMatrix groups, Rcpp::NumericVecto return result; } -// Recursive function for horizontal shrinkage (regression) +// Recursive function for hierarchical shrinkage (regression) //[[Rcpp::export]] void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, Rcpp::NumericVector& node_predictions, @@ -117,7 +117,7 @@ void hshrink_regr(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right } } -// Recursive function for horizontal shrinkage (probability) +// Recursive function for hierarchical shrinkage (probability) //[[Rcpp::export]] void hshrink_prob(Rcpp::IntegerVector& left_children, Rcpp::IntegerVector& right_children, Rcpp::IntegerVector& num_samples_nodes, diff --git a/tests/testthat/test_hshrink.R b/tests/testthat/test_hshrink.R index e6a43fcaa..94f2ddbc5 100644 --- a/tests/testthat/test_hshrink.R +++ b/tests/testthat/test_hshrink.R @@ -4,17 +4,17 @@ library(ranger) context("ranger_hshrink") ## Tests -test_that("horizontal shrinkage gives an error when node.stats=FALSE", { +test_that("hierarchical shrinkage gives an error when node.stats=FALSE", { rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = FALSE) expect_error(hshrink(rf, lambda = 5)) }) -test_that("horizontal shrinkage does not work for hard classification", { +test_that("hierarchical shrinkage does not work for hard classification", { rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = FALSE) expect_error(hshrink(rf, lambda = 5)) }) -test_that("horizontal shrinkage with lambda=0 doesn't change leafs and prediction, regression", { +test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, regression", { rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, node.stats = TRUE) split_values_before <- rf$forest$split.values[[1]] pred_before <- predict(rf, iris)$predictions @@ -25,7 +25,7 @@ test_that("horizontal shrinkage with lambda=0 doesn't change leafs and predictio expect_equal(pred_before, pred_after) }) -test_that("horizontal shrinkage with lambda=0 doesn't change leafs and prediction, probability", { +test_that("hierarchical shrinkage with lambda=0 doesn't change leafs and prediction, probability", { rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) pred_before <- predict(rf, iris)$predictions @@ -36,7 +36,7 @@ test_that("horizontal shrinkage with lambda=0 doesn't change leafs and predictio expect_equal(pred_before, pred_after) }) -test_that("horizontal shrinkage with lambda>0 does change leafs and prediction, regression", { +test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, regression", { rf <- ranger(Sepal.Length ~ ., iris, num.trees = 1, replace = FALSE, sample.fraction = 1, node.stats = TRUE) split_values_before <- rf$forest$split.values[[1]] pred_before <- predict(rf, iris)$predictions @@ -52,7 +52,7 @@ test_that("horizontal shrinkage with lambda>0 does change leafs and prediction, }) -test_that("horizontal shrinkage with lambda>0 does change leafs and prediction, probability", { +test_that("hierarchical shrinkage with lambda>0 does change leafs and prediction, probability", { rf <- ranger(Species ~ ., iris, num.trees = 1, node.stats = TRUE, probability = TRUE) class_freq_before <- simplify2array(rf$forest$terminal.class.counts[[1]]) pred_before <- predict(rf, iris)$predictions