diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..28be3a99 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 + +updates: + # Keep dependencies for GitHub Actions up-to-date + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 1d19e954..de55301f 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -29,7 +29,7 @@ jobs: R_KEEP_PKG_SOURCE: yes steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 diff --git a/.github/workflows/cpp-build.yaml b/.github/workflows/cpp-build.yaml index 5a0962f4..b6916d06 100644 --- a/.github/workflows/cpp-build.yaml +++ b/.github/workflows/cpp-build.yaml @@ -11,7 +11,7 @@ jobs: linux: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Build run: | sudo apt-get install cmake @@ -21,7 +21,7 @@ jobs: macos: runs-on: macos-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Build run: | mkdir build && pushd build diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index ed7650c7..29cc0336 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -22,7 +22,7 @@ jobs: permissions: contents: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: r-lib/actions/setup-pandoc@v2 @@ -41,7 +41,7 @@ jobs: - name: Deploy to GitHub pages 🚀 if: github.event_name != 'pull_request' - uses: JamesIves/github-pages-deploy-action@v4.4.1 + uses: JamesIves/github-pages-deploy-action@v4.6.0 with: clean: false branch: gh-pages diff --git a/DESCRIPTION b/DESCRIPTION index 30d979b2..3e2b79b9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,7 +19,7 @@ Suggests: survival, testthat Encoding: UTF-8 -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 
URL: http://imbs-hl.github.io/ranger/, https://github.com/imbs-hl/ranger BugReports: https://github.com/imbs-hl/ranger/issues diff --git a/NEWS.md b/NEWS.md index 5acbe06a..00b3a107 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,8 @@ # ranger 0.16.1 +* Set num.threads=2 as default; respect environment variables and options * Add hierarchical shrinkage +* Allow vector min.node.size and min.bucket for class-specific limits # ranger 0.16.0 * New CRAN version diff --git a/R/onAttach.R b/R/onAttach.R new file mode 100644 index 00000000..61b69dda --- /dev/null +++ b/R/onAttach.R @@ -0,0 +1,22 @@ + +.onAttach = function(libname, pkgname) { + if (!interactive()) { + return() + } + + threads_env <- Sys.getenv("R_RANGER_NUM_THREADS") + threads_option1 <- getOption("ranger.num.threads") + threads_option2 <- getOption("Ncpus") + + if (threads_env != "") { + thread_string <- paste(threads_env, "threads as set by environment variable R_RANGER_NUM_THREADS. Can be overwritten with num.threads.") + } else if (!is.null(threads_option1)) { + thread_string <- paste(threads_option1, "threads as set by options(ranger.num.threads = N). Can be overwritten with num.threads.") + } else if (!is.null(threads_option2)) { + thread_string <- paste(threads_option2, "threads as set by options(Ncpus = N). Can be overwritten with num.threads.") + } else { + thread_string <- "2 threads (default). Change with num.threads in ranger() and predict(), options(Ncpus = N), options(ranger.num.threads = N) or environment variable R_RANGER_NUM_THREADS." + } + + packageStartupMessage(paste("ranger", packageVersion("ranger"), "using", thread_string)) +} diff --git a/R/predict.R b/R/predict.R index 82599ab0..d11c453e 100644 --- a/R/predict.R +++ b/R/predict.R @@ -36,6 +36,9 @@ ##' ##' For classification and \code{predict.all = TRUE}, a factor levels are returned as numerics. ##' To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object. 
+##' +##' By default, ranger uses 2 threads. The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +##' R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. ##' ##' @title Ranger prediction ##' @param object Ranger \code{ranger.forest} object. @@ -45,7 +48,7 @@ ##' @param type Type of prediction. One of 'response', 'se', 'terminalNodes', 'quantiles' with default 'response'. See below for details. ##' @param se.method Method to compute standard errors. One of 'jack', 'infjack' with default 'infjack'. Only applicable if type = 'se'. See below for details. ##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. The seed is used in case of ties in classification mode. -##' @param num.threads Number of threads. Default is number of CPUs available. +##' @param num.threads Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below). ##' @param verbose Verbose output on or off. ##' @param inbag.counts Number of times the observations are in-bag in the trees. ##' @param ... further arguments passed to or from other methods. @@ -193,7 +196,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, ## Num threads ## Default 0 -> detect from system in C++. if (is.null(num.threads)) { - num.threads = 0 + num.threads <- as.integer(Sys.getenv("R_RANGER_NUM_THREADS", getOption("ranger.num.threads", getOption("Ncpus", 2L)))) } else if (!is.numeric(num.threads) || num.threads < 0) { stop("Error: Invalid value for num.threads") } @@ -433,6 +436,9 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, ##' ##' For classification and \code{predict.all = TRUE}, a factor levels are returned as numerics. 
##' To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object. +##' +##' By default, ranger uses 2 threads. The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +##' R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. ##' ##' @title Ranger prediction ##' @param object Ranger \code{ranger} object. @@ -444,7 +450,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, ##' @param quantiles Vector of quantiles for quantile prediction. Set \code{type = 'quantiles'} to use. ##' @param what User specified function for quantile prediction used instead of \code{quantile}. Must return numeric vector, see examples. ##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. The seed is used in case of ties in classification mode. -##' @param num.threads Number of threads. Default is number of CPUs available. +##' @param num.threads Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below). ##' @param verbose Verbose output on or off. ##' @param ... further arguments passed to or from other methods. ##' @return Object of class \code{ranger.prediction} with elements diff --git a/R/ranger.R b/R/ranger.R index 6d56d4d4..1cc11ff7 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -96,10 +96,10 @@ ##' To use only the SNPs without sex or other covariates from the phenotype file, use \code{0} on the right hand side of the formula. ##' Note that missing values are treated as an extra category while splitting. ##' -##' See \url{https://github.com/imbs-hl/ranger} for the development version. +##' By default, ranger uses 2 threads. 
The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +##' R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. ##' -##' With recent R versions, multithreading on Windows platforms should just work. -##' If you compile yourself, the new RTools toolchain is required. +##' See \url{https://github.com/imbs-hl/ranger} for the development version. ##' ##' @title Ranger ##' @param formula Object of class \code{formula} or \code{character} describing the model to fit. Interaction terms supported only for numerical variables. @@ -109,8 +109,8 @@ ##' @param importance Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification, the variance of the responses for regression and the sum of test statistics (see \code{splitrule}) for survival. ##' @param write.forest Save \code{ranger.forest} object, required for prediction. Set to \code{FALSE} to reduce memory usage if no prediction intended. ##' @param probability Grow a probability forest as in Malley et al. (2012). -##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. -##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. +##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. For classification, this can be a vector of class-specific values. +##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values. ##' @param max.depth Maximal tree depth. 
A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree). ##' @param replace Sample with replacement. ##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values. @@ -133,7 +133,7 @@ ##' @param quantreg Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction. ##' @param time.interest Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a vector of time points or a single number to use as many time points (grid over observed time points). ##' @param oob.error Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests. -##' @param num.threads Number of threads. Default is number of CPUs available. +##' @param num.threads Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below). ##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems. ##' @param verbose Show computation status and estimated runtime. ##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node. 
@@ -359,6 +359,15 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, stop("Error: Unsupported type of dependent variable.") } + ## Number of levels + if (treetype %in% c(1, 9)) { + if (is.factor(y)) { + num_levels <- nlevels(y) + } else { + num_levels <- length(unique(y)) + } + } + ## Quantile prediction only for regression if (quantreg && treetype != 3) { stop("Error: Quantile prediction implemented only for regression outcomes.") } @@ -514,7 +523,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Num threads ## Default 0 -> detect from system in C++. if (is.null(num.threads)) { - num.threads = 0 + num.threads <- as.integer(Sys.getenv("R_RANGER_NUM_THREADS", getOption("ranger.num.threads", getOption("Ncpus", 2L)))) } else if (!is.numeric(num.threads) || num.threads < 0) { stop("Error: Invalid value for num.threads") } @@ -522,16 +531,46 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Minimum node size if (is.null(min.node.size)) { min.node.size <- 0 - } else if (!is.numeric(min.node.size) || min.node.size < 0) { - stop("Error: Invalid value for min.node.size") + } else if (!is.numeric(min.node.size)) { + stop("Error: Invalid value for min.node.size.") + } + if (length(min.node.size) > 1) { + if (!(treetype %in% c(1, 9))) { + stop("Error: Invalid value for min.node.size. Vector values only valid for classification forests.") + } + if (any(min.node.size < 0)) { + stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.") + } + if (length(min.node.size) != num_levels) { + stop("Error: Invalid value for min.node.size. Expecting ", num_levels, " values, provided ", length(min.node.size), ".") + } + } else { + if (min.node.size < 0) { + stop("Error: Invalid value for min.node.size. 
Please give a nonnegative value or a vector of nonnegative values.") + } } ## Minimum bucket size if (is.null(min.bucket)) { min.bucket <- 0 - } else if (!is.numeric(min.bucket) || min.bucket < 0) { + } else if (!is.numeric(min.bucket)) { stop("Error: Invalid value for min.bucket.") } + if (length(min.bucket) > 1) { + if (!(treetype %in% c(1, 9))) { + stop("Error: Invalid value for min.bucket. Vector values only valid for classification forests.") + } + if (any(min.bucket < 0)) { + stop("Error: Invalid value for min.bucket. Please give a nonnegative value or a vector of nonnegative values.") + } + if (length(min.bucket) != num_levels) { + stop("Error: Invalid value for min.bucket. Expecting ", num_levels, " values, provided ", length(min.bucket), ".") + } + } else { + if (min.bucket < 0) { + stop("Error: Invalid value for min.bucket. Please give a nonnegative value or a vector of nonnegative values.") + } + } ## Tree depth if (is.null(max.depth)) { @@ -554,8 +593,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, if (sum(sample.fraction) <= 0) { stop("Error: Invalid value for sample.fraction. Sum of values must be >0.") } - if (length(sample.fraction) != nlevels(y)) { - stop("Error: Invalid value for sample.fraction. Expecting ", nlevels(y), " values, provided ", length(sample.fraction), ".") + if (length(sample.fraction) != num_levels) { + stop("Error: Invalid value for sample.fraction. 
Expecting ", num_levels, " values, provided ", length(sample.fraction), ".") } if (!replace & any(sample.fraction * length(y) > table(y))) { idx <- which(sample.fraction * length(y) > table(y))[1] @@ -1037,6 +1076,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, result$dependent.variable.name <- dependent.variable.name result$status.variable.name <- status.variable.name + ## Save max.depth + if (!is.null(max.depth)) { + result$max.depth <- max.depth + } + class(result) <- "ranger" ## Prepare quantile prediction diff --git a/man/predict.ranger.Rd b/man/predict.ranger.Rd index 362befca..2f9c63ac 100644 --- a/man/predict.ranger.Rd +++ b/man/predict.ranger.Rd @@ -38,7 +38,7 @@ \item{seed}{Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. The seed is used in case of ties in classification mode.} -\item{num.threads}{Number of threads. Default is number of CPUs available.} +\item{num.threads}{Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below).} \item{verbose}{Verbose output on or off.} @@ -70,6 +70,9 @@ If \code{type = 'se'} is selected, the method to estimate the variances can be c For classification and \code{predict.all = TRUE}, a factor levels are returned as numerics. To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object. + +By default, ranger uses 2 threads. The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. } \examples{ ## Classification forest diff --git a/man/predict.ranger.forest.Rd b/man/predict.ranger.forest.Rd index ba018b0e..805effda 100644 --- a/man/predict.ranger.forest.Rd +++ b/man/predict.ranger.forest.Rd @@ -33,7 +33,7 @@ \item{seed}{Random seed. 
Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. The seed is used in case of ties in classification mode.} -\item{num.threads}{Number of threads. Default is number of CPUs available.} +\item{num.threads}{Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below).} \item{verbose}{Verbose output on or off.} @@ -66,6 +66,9 @@ If \code{type = 'se'} is selected, the method to estimate the variances can be c For classification and \code{predict.all = TRUE}, a factor levels are returned as numerics. To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object. + +By default, ranger uses 2 threads. The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. } \references{ \itemize{ diff --git a/man/ranger.Rd b/man/ranger.Rd index 61c6e5df..9f8036b7 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -64,9 +64,9 @@ ranger( \item{probability}{Grow a probability forest as in Malley et al. (2012).} -\item{min.node.size}{Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability.} +\item{min.node.size}{Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. For classification, this can be a vector of class-specific values.} -\item{min.bucket}{Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types.} +\item{min.bucket}{Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values.} \item{max.depth}{Maximal tree depth. 
A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree).} @@ -112,7 +112,7 @@ ranger( \item{oob.error}{Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests.} -\item{num.threads}{Number of threads. Default is number of CPUs available.} +\item{num.threads}{Number of threads. Use 0 for all available cores. Default is 2 if not set by options/environment variables (see below).} \item{save.memory}{Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.} @@ -230,10 +230,10 @@ All SNPs in the \code{GenABEL} object will be used for splitting. To use only the SNPs without sex or other covariates from the phenotype file, use \code{0} on the right hand side of the formula. Note that missing values are treated as an extra category while splitting. -See \url{https://github.com/imbs-hl/ranger} for the development version. +By default, ranger uses 2 threads. The default can be changed with: (1) \code{num.threads} in ranger/predict call, (2) environment variable +R_RANGER_NUM_THREADS, (3) \code{options(ranger.num.threads = N)}, (4) \code{options(Ncpus = N)}, with precedence in that order. -With recent R versions, multithreading on Windows platforms should just work. -If you compile yourself, the new RTools toolchain is required. +See \url{https://github.com/imbs-hl/ranger} for the development version. 
} \examples{ ## Classification forest with default settings diff --git a/src/Forest.cpp b/src/Forest.cpp index 8c7a4242..7100cca3 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -27,7 +27,7 @@ namespace ranger { Forest::Forest() : - verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size(0), min_bucket(0), num_independent_variables(0), seed(0), num_samples( + verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size({0}), min_bucket({0}), num_independent_variables(0), seed(0), num_samples( 0), prediction_mode(false), memory_mode(MEM_DOUBLE), sample_with_replacement(true), memory_saving_splitting( false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction( { 1 }), holdout( false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), max_depth( @@ -62,6 +62,9 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode if (!load_forest_filename.empty()) { prediction_mode = true; } + + std::vector<uint> min_node_size_vector = { min_node_size }; + std::vector<uint> min_bucket_vector = { min_bucket }; // Sample fraction default and convert to vector if (sample_fraction == 0) { @@ -79,7 +82,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode // Call other init function init(loadDataFromFile(input_file), mtry, output_prefix, num_trees, seed, num_threads, importance_mode, - min_node_size, min_bucket, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, + min_node_size_vector, min_bucket_vector, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, splitrule, predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits, false, max_depth, regularization_factor, regularization_usedepth, false); @@ -133,7 +136,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode // #nocov end void
Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, - uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket, std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names, bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, std::vector<double>& case_weights, @@ -178,7 +181,7 @@ void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, } void Forest::init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix, - uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket, bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, bool order_snps, @@ -323,7 +326,7 @@ void Forest::writeOutput() { *verbose_out << "Sample size: " << num_samples << std::endl; *verbose_out << "Number of independent variables: " << num_independent_variables << std::endl; *verbose_out << "Mtry: " << mtry << std::endl; - *verbose_out << "Target node size: " << min_node_size << std::endl; + *verbose_out << "Target node size: " << min_node_size[0] << std::endl; *verbose_out << "Variable importance mode: " << importance_mode << std::endl; *verbose_out << "Memory mode: " << memory_mode << std::endl; *verbose_out << "Seed: " << seed << std::endl; @@ -473,7 +476,7 @@ void Forest::grow() { } trees[i]->init(data.get(), mtry, num_samples, tree_seed, 
&deterministic_varIDs, tree_split_select_weights, - importance_mode, min_node_size, min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights, + importance_mode, &min_node_size, &min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights, tree_manual_inbag, keep_inbag, &sample_fraction, alpha, minprop, holdout, num_random_splits, max_depth, &regularization_factor, regularization_usedepth, &split_varIDs_used, save_node_stats); } diff --git a/src/Forest.h b/src/Forest.h index 5b297202..e9279154 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -48,7 +48,7 @@ class Forest { bool holdout, PredictionType prediction_type, uint num_random_splits, uint max_depth, const std::vector<double>& regularization_factor, bool regularization_usedepth); void initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, - uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket, std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names, bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, @@ -58,7 +58,7 @@ class Forest { const std::vector<double>& regularization_factor, bool regularization_usedepth, bool node_stats); void init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix, - uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket, bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool 
holdout, PredictionType prediction_type, uint num_random_splits, @@ -119,10 +119,10 @@ class Forest { uint getMtry() const { return mtry; } - uint getMinNodeSize() const { + const std::vector<uint>& getMinNodeSize() const { return min_node_size; } - uint getMinBucket() const { + const std::vector<uint>& getMinBucket() const { return min_bucket; } size_t getNumIndependentVariables() const { @@ -209,8 +209,8 @@ class Forest { std::vector<std::string> dependent_variable_names; // time,status for survival size_t num_trees; uint mtry; - uint min_node_size; - uint min_bucket; + std::vector<uint> min_node_size; + std::vector<uint> min_bucket; size_t num_independent_variables; uint seed; size_t num_samples; diff --git a/src/ForestClassification.cpp b/src/ForestClassification.cpp index bb6861f7..1f26aeb3 100644 --- a/src/ForestClassification.cpp +++ b/src/ForestClassification.cpp @@ -54,13 +54,13 @@ void ForestClassification::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET; } // Create class_values and response_classIDs diff --git a/src/ForestProbability.cpp b/src/ForestProbability.cpp index 40922554..817b2de5 100644 --- a/src/ForestProbability.cpp +++ b/src/ForestProbability.cpp @@ -59,13 +59,13 @@ void ForestProbability::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_PROBABILITY; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET; } 
// Create class_values and response_classIDs diff --git a/src/ForestRegression.cpp b/src/ForestRegression.cpp index 7c1bb326..6328ac20 100644 --- a/src/ForestRegression.cpp +++ b/src/ForestRegression.cpp @@ -48,13 +48,13 @@ void ForestRegression::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_REGRESSION; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_REGRESSION; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET; } // Error if beta splitrule used with data outside of [0,1] diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp index 3b88e312..9a31f3da 100644 --- a/src/ForestSurvival.cpp +++ b/src/ForestSurvival.cpp @@ -98,13 +98,13 @@ void ForestSurvival::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_SURVIVAL; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_SURVIVAL; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET_SURVIVAL; } // Sort data if extratrees and not memory saving mode diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index b34f34a6..0c9ac2bf 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,7 +13,7 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& 
always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, bool node_stats, std::vector& time_interest, bool use_time_interest); +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, bool node_stats, 
std::vector& time_interest, bool use_time_interest); RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP node_statsSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; @@ -29,8 +29,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< uint >::type num_threads(num_threadsSEXP); Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP); Rcpp::traits::input_parameter< uint >::type importance_mode_r(importance_mode_rSEXP); - Rcpp::traits::input_parameter< uint >::type min_node_size(min_node_sizeSEXP); - Rcpp::traits::input_parameter< uint >::type min_bucket(min_bucketSEXP); + Rcpp::traits::input_parameter< std::vector& >::type min_node_size(min_node_sizeSEXP); + Rcpp::traits::input_parameter< std::vector& >::type min_bucket(min_bucketSEXP); Rcpp::traits::input_parameter< std::vector>& >::type 
split_select_weights(split_select_weightsSEXP); Rcpp::traits::input_parameter< bool >::type use_split_select_weights(use_split_select_weightsSEXP); Rcpp::traits::input_parameter< std::vector& >::type always_split_variable_names(always_split_variable_namesSEXP); diff --git a/src/Tree.cpp b/src/Tree.cpp index 57d3dfbd..fd97f686 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -39,7 +39,7 @@ Tree::Tree(std::vector>& child_nodeIDs, std::vector& } void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, - std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + std::vector* split_select_weights, ImportanceMode importance_mode, std::vector* min_node_size, std::vector* min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, std::vector* sample_fraction, double alpha, double minprop, bool holdout, uint num_random_splits, uint max_depth, std::vector* regularization_factor, @@ -90,7 +90,7 @@ void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std: void Tree::grow(std::vector* variable_importance) { // Allocate memory for tree growing allocateMemory(); - + this->variable_importance = variable_importance; // Bootstrap, dependent if weighted or not and with or without replacement @@ -307,7 +307,7 @@ void Tree::createPossibleSplitVarSubset(std::vector& result) { } bool Tree::splitNode(size_t nodeID) { - + // Select random subset of variables to possibly split at std::vector possible_split_varIDs; createPossibleSplitVarSubset(possible_split_varIDs); diff --git a/src/Tree.h b/src/Tree.h index 101c300d..17b98ff0 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -36,7 +36,7 @@ class Tree { Tree& operator=(const Tree&) = delete; void init(const Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, - std::vector* 
split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + std::vector* split_select_weights, ImportanceMode importance_mode, std::vector* min_node_size, std::vector* min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, std::vector* sample_fraction, double alpha, double minprop, bool holdout, uint num_random_splits, @@ -166,10 +166,10 @@ class Tree { size_t num_samples_oob; // Minimum node size to split, nodes of smaller size can be produced - uint min_node_size; + std::vector* min_node_size; // Minimum bucket size, minimum number of samples in each node - uint min_bucket; + std::vector* min_bucket; // Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable // Deterministic variables are always selected diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index bbb3b581..a2129777 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -85,7 +85,7 @@ bool TreeClassification::splitNodeInternal(size_t nodeID, std::vector& p } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if ((min_node_size->size() == 1 && num_samples_node <= (*min_node_size)[0]) || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { split_values[nodeID] = estimate(nodeID); return true; } @@ -166,34 +166,54 @@ bool TreeClassification::findBestSplit(size_t nodeID, std::vector& possi uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + + // Stop early if no split posssible + if 
(min_bucket->size() == 1) { + if (num_samples_node < 2 * (*min_bucket)[0]) { + return true; + } + } else { + uint sum_min_bucket = 0; + for (size_t j = 0; j < num_classes; ++j) { + sum_min_bucket += (*min_bucket)[j]; + } + if (num_samples_node < sum_min_bucket) { + return true; + } + } -// Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { - - // For all possible split variables - for (auto& varID : possible_split_varIDs) { - // Find best split value, if ordered consider all values as split values, else all 2-partitions - if (data->isOrderedVariable(varID)) { + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { - // Use memory saving method if option set - if (memory_saving_splitting) { + // Use memory saving method if option set + if (memory_saving_splitting) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + // Use faster method for both cases + double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); + if (q < Q_THRESHOLD) { findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } else { - // Use faster method for both cases - double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); - if (q < Q_THRESHOLD) { - findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); - } else { - findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); - } + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); } - } else { - findBestSplitValueUnordered(nodeID, 
varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); } + } else { + findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); } } @@ -283,7 +303,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -317,6 +337,21 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -375,7 +410,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -409,6 +444,21 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + 
} + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -486,7 +536,7 @@ void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -516,6 +566,21 @@ void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID // Decrease of impurity decrease = sum_left / (double) n_left + sum_right / (double) n_right; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -544,20 +609,40 @@ bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vectorsize() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() == 1) { + if (num_samples_node < 2 * (*min_bucket)[0]) { + return true; + } + } else { + uint sum_min_bucket = 0; + for (size_t j = 0; j < num_classes; ++j) { + sum_min_bucket += (*min_bucket)[j]; + } + if (num_samples_node < sum_min_bucket) { + return true; + } + } - // For all possible split variables - for (auto& varID : possible_split_varIDs) { - // Find best split value, if ordered consider all values as split values, else all 2-partitions - if (data->isOrderedVariable(varID)) { - findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - 
best_decrease); - } else { - findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, - best_varID, best_decrease); - } + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, + best_varID, best_decrease); } } @@ -657,7 +742,7 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0])) { continue; } @@ -671,6 +756,21 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; @@ -768,7 +868,7 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left 
< (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -782,6 +882,21 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right; diff --git a/src/TreeProbability.cpp b/src/TreeProbability.cpp index 53d630b7..c92a5713 100644 --- a/src/TreeProbability.cpp +++ b/src/TreeProbability.cpp @@ -89,7 +89,7 @@ bool TreeProbability::splitNodeInternal(size_t nodeID, std::vector& poss } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if ((min_node_size->size() == 1 && num_samples_node <= (*min_node_size)[0]) || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { addToTerminalNodes(nodeID); } @@ -170,34 +170,54 @@ bool TreeProbability::findBestSplit(size_t nodeID, std::vector& possible uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } - + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() == 1) { + if (num_samples_node < 2 * (*min_bucket)[0]) { + return true; + } + } else { + uint 
sum_min_bucket = 0; + for (size_t j = 0; j < num_classes; ++j) { + sum_min_bucket += (*min_bucket)[j]; + } + if (num_samples_node < sum_min_bucket) { + return true; + } + } - // For all possible split variables - for (auto& varID : possible_split_varIDs) { - // Find best split value, if ordered consider all values as split values, else all 2-partitions - if (data->isOrderedVariable(varID)) { + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { - // Use memory saving method if option set - if (memory_saving_splitting) { + // Use memory saving method if option set + if (memory_saving_splitting) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + // Use faster method for both cases + double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); + if (q < Q_THRESHOLD) { findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } else { - // Use faster method for both cases - double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); - if (q < Q_THRESHOLD) { - findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); - } else { - findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); - } + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); } - } else { - findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); } + } else { + findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, 
best_value, best_varID, + best_decrease); } } @@ -287,7 +307,7 @@ void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -321,6 +341,21 @@ void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -379,7 +414,7 @@ void TreeProbability::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -413,6 +448,21 @@ void TreeProbability::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -490,7 +540,7 @@ void TreeProbability::findBestSplitValueUnordered(size_t nodeID, size_t varID, s size_t 
n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -520,6 +570,21 @@ void TreeProbability::findBestSplitValueUnordered(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_left / (double) n_left + sum_right / (double) n_right; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -548,20 +613,40 @@ bool TreeProbability::findBestSplitExtraTrees(size_t nodeID, std::vector uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } - + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() == 1) { + if (num_samples_node < 2 * (*min_bucket)[0]) { + return true; + } + } else { + uint sum_min_bucket = 0; + for (size_t j = 0; j < num_classes; ++j) { + sum_min_bucket += (*min_bucket)[j]; + } + if (num_samples_node < sum_min_bucket) { + return true; + } + } - // For all possible split variables - for (auto& varID : possible_split_varIDs) { - // Find best split value, if ordered consider all values as split values, else all 2-partitions - if (data->isOrderedVariable(varID)) { - findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); - } else { - 
findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, - best_varID, best_decrease); - } + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, + best_varID, best_decrease); } } @@ -661,7 +746,7 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0])) { continue; } @@ -675,6 +760,21 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; @@ -772,7 +872,7 @@ void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_ size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < 
(*min_bucket)[0])) { continue; } @@ -786,6 +886,21 @@ void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_ sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right; diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index c272695b..ec59528f 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -68,7 +68,7 @@ bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector& possi } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { split_values[nodeID] = estimate(nodeID); return true; } @@ -150,7 +150,7 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_ } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -261,7 +261,7 @@ void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, doubl } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -323,7 +323,7 @@ void TreeRegression::findBestSplitValueLargeQ(size_t nodeID, 
size_t varID, doubl } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -405,7 +405,7 @@ void TreeRegression::findBestSplitValueUnordered(size_t nodeID, size_t varID, do size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -548,7 +548,7 @@ bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector& } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -658,7 +658,7 @@ void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, d } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0]) { continue; } @@ -758,7 +758,7 @@ void TreeRegression::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -793,7 +793,7 @@ bool TreeRegression::findBestSplitBeta(size_t nodeID, std::vector& possi } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables find best split value for (auto& varID : possible_split_varIDs) { @@ -886,7 +886,7 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0]) { 
continue; } diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index 1c60ba8b..7ae94768 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -140,7 +140,7 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_sp } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeSurvival(nodeID); } @@ -148,7 +148,7 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_sp } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -200,7 +200,7 @@ bool TreeSurvival::findBestSplitMaxstat(size_t nodeID, std::vector& poss size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID]; // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeDeathCounts(nodeID); computeSurvival(nodeID); @@ -414,7 +414,7 @@ void TreeSurvival::findBestSplitValueLogRank(size_t nodeID, size_t varID, double // Stop if minimal bucket size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child[i]; - if (num_samples_right_child[i] < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child[i] < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -520,7 +520,7 @@ void TreeSurvival::findBestSplitValueLogRankUnordered(size_t nodeID, size_t varI // Stop if minimal bucket size reached size_t num_samples_left_child = 
num_samples_node - num_samples_right_child; - if (num_samples_right_child < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -611,7 +611,7 @@ void TreeSurvival::findBestSplitValueAUC(size_t nodeID, size_t varID, double& be for (size_t i = 0; i < num_splits; ++i) { // Do not consider this split point if fewer than min_bucket samples in one node size_t num_samples_right_child = num_node_samples - num_samples_left_child[i]; - if (num_samples_left_child[i] < min_bucket || num_samples_right_child < min_bucket) { + if (num_samples_left_child[i] < (*min_bucket)[0] || num_samples_right_child < (*min_bucket)[0]) { continue; } else { double auc = fabs((num_count[i] / 2) / num_total[i] - 0.5); @@ -711,7 +711,7 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& p } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeSurvival(nodeID); } @@ -719,7 +719,7 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& p } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -805,7 +805,7 @@ void TreeSurvival::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, dou // Stop if minimal node size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child[i]; - if (num_samples_right_child[i] < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child[i] < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -934,7 +934,7 @@ void 
TreeSurvival::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t v // Stop if minimal node size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child; - if (num_samples_right_child < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index c8c4fed2..66e34c0e 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -50,7 +50,7 @@ using namespace ranger; // [[Rcpp::export]] Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, - bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, + bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, diff --git a/tests/testthat/test_classification.R b/tests/testthat/test_classification.R index 0d015c9d..4690afe5 100644 --- a/tests/testthat/test_classification.R +++ b/tests/testthat/test_classification.R @@ -10,9 +10,9 @@ rg.class <- ranger(Species ~ ., data = iris) rg.mat <- ranger(dependent.variable.name = "Species", data = dat, classification = TRUE) ## Basic tests (for all random forests equal) -test_that("classification result is of class ranger with 15 elements", { +test_that("classification result is of class ranger with 16 elements", { expect_is(rg.class, "ranger") - expect_equal(length(rg.class), 15) + expect_equal(length(rg.class), 16) }) test_that("classification prediction returns factor", { diff --git a/tests/testthat/test_print.R b/tests/testthat/test_print.R index 3ca91b4a..8563b1e3 100644 --- 
a/tests/testthat/test_print.R +++ b/tests/testthat/test_print.R @@ -16,7 +16,7 @@ expect_that(print(rf$forest), prints_text("Ranger forest object")) expect_that(print(predict(rf, iris)), prints_text("Ranger prediction")) ## Test str ranger function -expect_that(str(rf), prints_text("List of 15")) +expect_that(str(rf), prints_text("List of 16")) ## Test str forest function expect_that(str(rf$forest), prints_text("List of 9")) diff --git a/tests/testthat/test_ranger.R b/tests/testthat/test_ranger.R index 0f2c0118..40c788c0 100644 --- a/tests/testthat/test_ranger.R +++ b/tests/testthat/test_ranger.R @@ -86,6 +86,14 @@ test_that("Inbag counts match sample fraction, classification", { expect_equal(unique(colSums(inbag[dat$Species == "setosa", ])), 15) expect_equal(unique(colSums(inbag[dat$Species == "versicolor", ])), 30) expect_equal(unique(colSums(inbag[dat$Species == "virginica", ])), 45) + + ## No factor outcome + rf <- ranger(Species ~ ., data.matrix(iris), num.trees = 5, sample.fraction = c(0.2, 0.3, 0.4), + replace = TRUE, keep.inbag = TRUE, classification = TRUE) + inbag <- do.call(cbind, rf$inbag.counts) + expect_equal(unique(colSums(inbag[iris$Species == "setosa", ])), 30) + expect_equal(unique(colSums(inbag[iris$Species == "versicolor", ])), 45) + expect_equal(unique(colSums(inbag[iris$Species == "virginica", ])), 60) }) test_that("Inbag counts match sample fraction, probability", { @@ -104,6 +112,14 @@ test_that("Inbag counts match sample fraction, probability", { expect_equal(unique(colSums(inbag[1:50, ])), 15) expect_equal(unique(colSums(inbag[51:100, ])), 30) expect_equal(unique(colSums(inbag[101:150, ])), 45) + + ## No factor outcome + rf <- ranger(Species ~ ., data.matrix(iris), num.trees = 5, sample.fraction = c(0.2, 0.3, 0.4), + replace = TRUE, keep.inbag = TRUE, probability = TRUE) + inbag <- do.call(cbind, rf$inbag.counts) + expect_equal(unique(colSums(inbag[1:50, ])), 30) + expect_equal(unique(colSums(inbag[51:100, ])), 45) + 
expect_equal(unique(colSums(inbag[101:150, ])), 60) }) test_that("as.factor() in formula works", { @@ -382,3 +398,85 @@ test_that("min.bucket creates nodes of correct size", { })) expect_gte(smallest_node, min.bucket) }) + +test_that("Vector min.bucket creates nodes of correct size", { + + # Size 2,3,4 + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = c(2, 3, 4), keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(2, 3, 4), ncol = 5, nrow = 3))) + + # Size 4,3,2 + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = c(4, 3, 2), keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(4, 3, 2), ncol = 5, nrow = 3))) + + # Random size + min.bucket <- round(runif(3, 1, 10)) + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = min.bucket, keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- 
sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(min.bucket, ncol = 5, nrow = 3))) + + # No factor outcome + rf <- ranger(Species ~ ., data.matrix(iris), num.trees = 5, replace = FALSE, + min.bucket = c(2, 3, 4), keep.inbag = TRUE, classification = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(2, 3, 4), ncol = 5, nrow = 3))) +}) + + diff --git a/tests/testthat/test_regression.R b/tests/testthat/test_regression.R index dd3bdd4e..8949f82d 100644 --- a/tests/testthat/test_regression.R +++ b/tests/testthat/test_regression.R @@ -7,9 +7,9 @@ context("ranger_reg") rg.reg <- ranger(Sepal.Length ~ ., data = iris) ## Basic tests (for all random forests equal) -test_that("regression result is of class ranger with 15 elements", { +test_that("regression result is of class ranger with 16 elements", { expect_is(rg.reg, "ranger") - expect_equal(length(rg.reg), 15) + expect_equal(length(rg.reg), 16) }) test_that("regression prediction returns numeric vector", { diff --git a/tests/testthat/test_survival.R b/tests/testthat/test_survival.R index 6226eb6f..358a4096 100644 --- a/tests/testthat/test_survival.R +++ b/tests/testthat/test_survival.R @@ -8,9 +8,9 
@@ context("ranger_surv") rg.surv <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 10) ## Basic tests (for all random forests equal) -test_that("survival result is of class ranger with 17 elements", { +test_that("survival result is of class ranger with 18 elements", { expect_is(rg.surv, "ranger") - expect_equal(length(rg.surv), 17) + expect_equal(length(rg.surv), 18) }) test_that("results have right number of trees", {