diff --git a/DESCRIPTION b/DESCRIPTION index c842a425..978e2b99 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests -Version: 0.16.0 -Date: 2023-11-09 +Version: 0.16.1 +Date: 2024-05-15 Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] Maintainer: Marvin N. Wright Description: A fast implementation of Random Forests, particularly suited for high @@ -19,7 +19,7 @@ Suggests: survival, testthat Encoding: UTF-8 -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 URL: http://imbs-hl.github.io/ranger/, https://github.com/imbs-hl/ranger BugReports: https://github.com/imbs-hl/ranger/issues diff --git a/NEWS.md b/NEWS.md index ad66d327..a583b20c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,7 @@ +# ranger 0.16.1 +* Allow vector min.node.size and min.bucket for class-specific limits + # ranger 0.16.0 * New CRAN version diff --git a/R/ranger.R b/R/ranger.R index 6d56d4d4..65e20df8 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -109,8 +109,8 @@ ##' @param importance Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification, the variance of the responses for regression and the sum of test statistics (see \code{splitrule}) for survival. ##' @param write.forest Save \code{ranger.forest} object, required for prediction. Set to \code{FALSE} to reduce memory usage if no prediction intended. ##' @param probability Grow a probability forest as in Malley et al. (2012). -##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. -##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. +##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. 
For classification, this can be a vector of class-specific values. +##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values. ##' @param max.depth Maximal tree depth. A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree). ##' @param replace Sample with replacement. ##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values. @@ -359,6 +359,15 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, stop("Error: Unsupported type of dependent variable.") } + ## Number of levels + if (treetype %in% c(1, 9)) { + if (is.factor(y)) { + num_levels <- nlevels(y) + } else { + num_levels <- length(unique(y)) + } + } + ## Quantile prediction only for regression if (quantreg && treetype != 3) { stop("Error: Quantile prediction implemented only for regression outcomes.") @@ -522,16 +531,46 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Minimum node size if (is.null(min.node.size)) { min.node.size <- 0 - } else if (!is.numeric(min.node.size) || min.node.size < 0) { - stop("Error: Invalid value for min.node.size") + } else if (!is.numeric(min.node.size)) { + stop("Error: Invalid value for min.node.size.") + } + if (length(min.node.size) > 1) { + if (!(treetype %in% c(1, 9))) { + stop("Error: Invalid value for min.node.size. Vector values only valid for classification forests.") + } + if (any(min.node.size < 0)) { + stop("Error: Invalid value for min.node.size. 
Please give a nonnegative value or a vector of nonnegative values.") + } + if (length(min.node.size) != num_levels) { + stop("Error: Invalid value for min.node.size. Expecting ", num_levels, " values, provided ", length(min.node.size), ".") + } + } else { + if (min.node.size < 0) { + stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.") + } } ## Minimum bucket size if (is.null(min.bucket)) { min.bucket <- 0 - } else if (!is.numeric(min.bucket) || min.bucket < 0) { + } else if (!is.numeric(min.bucket)) { stop("Error: Invalid value for min.bucket") } + if (length(min.bucket) > 1) { + if (!(treetype %in% c(1, 9))) { + stop("Error: Invalid value for min.bucket. Vector values only valid for classification forests.") + } + if (any(min.bucket < 0)) { + stop("Error: Invalid value for min.bucket. Please give a nonnegative value or a vector of nonnegative values.") + } + if (length(min.bucket) != num_levels) { + stop("Error: Invalid value for min.bucket. Expecting ", num_levels, " values, provided ", length(min.bucket), ".") + } + } else { + if (min.bucket < 0) { + stop("Error: Invalid value for min.bucket. Please give a nonnegative value or a vector of nonnegative values.") + } + } ## Tree depth if (is.null(max.depth)) { @@ -554,8 +593,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, if (sum(sample.fraction) <= 0) { stop("Error: Invalid value for sample.fraction. Sum of values must be >0.") } - if (length(sample.fraction) != nlevels(y)) { - stop("Error: Invalid value for sample.fraction. Expecting ", nlevels(y), " values, provided ", length(sample.fraction), ".") + if (length(sample.fraction) != num_levels) { + stop("Error: Invalid value for sample.fraction. 
Expecting ", num_levels, " values, provided ", length(sample.fraction), ".") } if (!replace & any(sample.fraction * length(y) > table(y))) { idx <- which(sample.fraction * length(y) > table(y))[1] diff --git a/man/ranger.Rd b/man/ranger.Rd index 61c6e5df..6b7465b8 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -64,9 +64,9 @@ ranger( \item{probability}{Grow a probability forest as in Malley et al. (2012).} -\item{min.node.size}{Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability.} +\item{min.node.size}{Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. For classification, this can be a vector of class-specific values.} -\item{min.bucket}{Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types.} +\item{min.bucket}{Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values.} \item{max.depth}{Maximal tree depth. 
A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree).} diff --git a/src/Forest.cpp b/src/Forest.cpp index 8c7a4242..7100cca3 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -27,7 +27,7 @@ namespace ranger { Forest::Forest() : - verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size(0), min_bucket(0), num_independent_variables(0), seed(0), num_samples( + verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size({0}), min_bucket({0}), num_independent_variables(0), seed(0), num_samples( 0), prediction_mode(false), memory_mode(MEM_DOUBLE), sample_with_replacement(true), memory_saving_splitting( false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction( { 1 }), holdout( false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), max_depth( @@ -62,6 +62,9 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode if (!load_forest_filename.empty()) { prediction_mode = true; } + + std::vector min_node_size_vector = { min_node_size }; + std::vector min_bucket_vector = { min_bucket }; // Sample fraction default and convert to vector if (sample_fraction == 0) { @@ -79,7 +82,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode // Call other init function init(loadDataFromFile(input_file), mtry, output_prefix, num_trees, seed, num_threads, importance_mode, - min_node_size, min_bucket, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, + min_node_size_vector, min_bucket_vector, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, splitrule, predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits, false, max_depth, regularization_factor, regularization_usedepth, false); @@ -133,7 +136,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode 
memory_mode // #nocov end void Forest::initR(std::unique_ptr input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, - uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_threads, ImportanceMode importance_mode, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, const std::vector& always_split_variable_names, bool prediction_mode, bool sample_with_replacement, const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, std::vector& case_weights, @@ -178,7 +181,7 @@ void Forest::initR(std::unique_ptr input_data, uint mtry, uint num_trees, } void Forest::init(std::unique_ptr input_data, uint mtry, std::string output_prefix, - uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector& min_node_size, std::vector& min_bucket, bool prediction_mode, bool sample_with_replacement, const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector& sample_fraction, double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, bool order_snps, @@ -323,7 +326,7 @@ void Forest::writeOutput() { *verbose_out << "Sample size: " << num_samples << std::endl; *verbose_out << "Number of independent variables: " << num_independent_variables << std::endl; *verbose_out << "Mtry: " << mtry << std::endl; - *verbose_out << "Target node size: " << min_node_size << std::endl; + *verbose_out << "Target node size: " << min_node_size[0] << std::endl; *verbose_out << "Variable importance mode: " << importance_mode << std::endl; *verbose_out << "Memory mode: " << memory_mode << std::endl; *verbose_out << "Seed: " << seed << std::endl; @@ -473,7 +476,7 @@ void Forest::grow() { } trees[i]->init(data.get(), mtry, 
num_samples, tree_seed, &deterministic_varIDs, tree_split_select_weights, - importance_mode, min_node_size, min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights, + importance_mode, &min_node_size, &min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights, tree_manual_inbag, keep_inbag, &sample_fraction, alpha, minprop, holdout, num_random_splits, max_depth, &regularization_factor, regularization_usedepth, &split_varIDs_used, save_node_stats); } diff --git a/src/Forest.h b/src/Forest.h index 5b297202..e9279154 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -48,7 +48,7 @@ class Forest { bool holdout, PredictionType prediction_type, uint num_random_splits, uint max_depth, const std::vector& regularization_factor, bool regularization_usedepth); void initR(std::unique_ptr input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, - uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_threads, ImportanceMode importance_mode, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, const std::vector& always_split_variable_names, bool prediction_mode, bool sample_with_replacement, const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, @@ -58,7 +58,7 @@ class Forest { const std::vector& regularization_factor, bool regularization_usedepth, bool node_stats); void init(std::unique_ptr input_data, uint mtry, std::string output_prefix, - uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector& min_node_size, std::vector& min_bucket, bool prediction_mode, bool sample_with_replacement, const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector& sample_fraction, double alpha, 
double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, @@ -119,10 +119,10 @@ class Forest { uint getMtry() const { return mtry; } - uint getMinNodeSize() const { + const std::vector& getMinNodeSize() const { return min_node_size; } - uint getMinBucket() const { + const std::vector& getMinBucket() const { return min_bucket; } size_t getNumIndependentVariables() const { @@ -209,8 +209,8 @@ class Forest { std::vector dependent_variable_names; // time,status for survival size_t num_trees; uint mtry; - uint min_node_size; - uint min_bucket; + std::vector min_node_size; + std::vector min_bucket; size_t num_independent_variables; uint seed; size_t num_samples; diff --git a/src/ForestClassification.cpp b/src/ForestClassification.cpp index bb6861f7..1f26aeb3 100644 --- a/src/ForestClassification.cpp +++ b/src/ForestClassification.cpp @@ -54,13 +54,13 @@ void ForestClassification::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET; } // Create class_values and response_classIDs diff --git a/src/ForestProbability.cpp b/src/ForestProbability.cpp index 40922554..817b2de5 100644 --- a/src/ForestProbability.cpp +++ b/src/ForestProbability.cpp @@ -59,13 +59,13 @@ void ForestProbability::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_PROBABILITY; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = 
DEFAULT_MIN_BUCKET; } // Create class_values and response_classIDs diff --git a/src/ForestRegression.cpp b/src/ForestRegression.cpp index 7c1bb326..6328ac20 100644 --- a/src/ForestRegression.cpp +++ b/src/ForestRegression.cpp @@ -48,13 +48,13 @@ void ForestRegression::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_REGRESSION; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_REGRESSION; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET; } // Error if beta splitrule used with data outside of [0,1] diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp index 3b88e312..9a31f3da 100644 --- a/src/ForestSurvival.cpp +++ b/src/ForestSurvival.cpp @@ -98,13 +98,13 @@ void ForestSurvival::initInternal() { } // Set minimal node size - if (min_node_size == 0) { - min_node_size = DEFAULT_MIN_NODE_SIZE_SURVIVAL; + if (min_node_size.size() == 1 && min_node_size[0] == 0) { + min_node_size[0] = DEFAULT_MIN_NODE_SIZE_SURVIVAL; } // Set minimal bucket size - if (min_bucket == 0) { - min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL; + if (min_bucket.size() == 1 && min_bucket[0] == 0) { + min_bucket[0] = DEFAULT_MIN_BUCKET_SURVIVAL; } // Sort data if extratrees and not memory saving mode diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index a47c2234..254fbe78 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,7 +13,7 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool 
use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, bool node_stats, std::vector& time_interest, bool use_time_interest); +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool 
regularization_usedepth, bool node_stats, std::vector& time_interest, bool use_time_interest); RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP node_statsSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; @@ -29,8 +29,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< uint >::type num_threads(num_threadsSEXP); Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP); Rcpp::traits::input_parameter< uint >::type importance_mode_r(importance_mode_rSEXP); - Rcpp::traits::input_parameter< uint >::type min_node_size(min_node_sizeSEXP); - Rcpp::traits::input_parameter< uint >::type min_bucket(min_bucketSEXP); + Rcpp::traits::input_parameter< std::vector& >::type min_node_size(min_node_sizeSEXP); + Rcpp::traits::input_parameter< std::vector& >::type min_bucket(min_bucketSEXP); 
Rcpp::traits::input_parameter< std::vector>& >::type split_select_weights(split_select_weightsSEXP); Rcpp::traits::input_parameter< bool >::type use_split_select_weights(use_split_select_weightsSEXP); Rcpp::traits::input_parameter< std::vector& >::type always_split_variable_names(always_split_variable_namesSEXP); diff --git a/src/Tree.cpp b/src/Tree.cpp index 57d3dfbd..fd97f686 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -39,7 +39,7 @@ Tree::Tree(std::vector>& child_nodeIDs, std::vector& } void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, - std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + std::vector* split_select_weights, ImportanceMode importance_mode, std::vector* min_node_size, std::vector* min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, std::vector* sample_fraction, double alpha, double minprop, bool holdout, uint num_random_splits, uint max_depth, std::vector* regularization_factor, @@ -90,7 +90,7 @@ void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std: void Tree::grow(std::vector* variable_importance) { // Allocate memory for tree growing allocateMemory(); - + this->variable_importance = variable_importance; // Bootstrap, dependent if weighted or not and with or without replacement @@ -307,7 +307,7 @@ void Tree::createPossibleSplitVarSubset(std::vector& result) { } bool Tree::splitNode(size_t nodeID) { - + // Select random subset of variables to possibly split at std::vector possible_split_varIDs; createPossibleSplitVarSubset(possible_split_varIDs); diff --git a/src/Tree.h b/src/Tree.h index 101c300d..17b98ff0 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -36,7 +36,7 @@ class Tree { Tree& operator=(const Tree&) = delete; void init(const Data* data, uint mtry, size_t num_samples, uint seed, 
std::vector* deterministic_varIDs, - std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, + std::vector* split_select_weights, ImportanceMode importance_mode, std::vector* min_node_size, std::vector* min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, std::vector* sample_fraction, double alpha, double minprop, bool holdout, uint num_random_splits, @@ -166,10 +166,10 @@ class Tree { size_t num_samples_oob; // Minimum node size to split, nodes of smaller size can be produced - uint min_node_size; + std::vector* min_node_size; // Minimum bucket size, minimum number of samples in each node - uint min_bucket; + std::vector* min_bucket; // Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable // Deterministic variables are always selected diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index bbb3b581..28ca371f 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -85,7 +85,7 @@ bool TreeClassification::splitNodeInternal(size_t nodeID, std::vector& p } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if ((min_node_size->size() == 1 && num_samples_node <= (*min_node_size)[0]) || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { split_values[nodeID] = estimate(nodeID); return true; } @@ -166,9 +166,19 @@ bool TreeClassification::findBestSplit(size_t nodeID, std::vector& possi uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } 
-// Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + // TODO: Possible to stop early for class-wise min_bucket? + // Stop early if no split posssible + if (min_bucket->size() > 1 || (num_samples_node >= 2 * (*min_bucket)[0])) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -283,7 +293,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -317,6 +327,21 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -375,7 +400,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -409,6 +434,21 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if 
(class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -486,7 +526,7 @@ void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -516,6 +556,21 @@ void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID // Decrease of impurity decrease = sum_left / (double) n_left + sum_right / (double) n_right; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -544,9 +599,19 @@ bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vectorsize() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + // TODO: Possible to stop early for class-wise min_bucket? 
// Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() > 1 || (num_samples_node >= 2 * (*min_bucket)[0])) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -657,7 +722,7 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0])) { continue; } @@ -671,6 +736,21 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; @@ -768,7 +848,7 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -782,6 +862,21 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < 
num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right; diff --git a/src/TreeProbability.cpp b/src/TreeProbability.cpp index 53d630b7..5cb61f46 100644 --- a/src/TreeProbability.cpp +++ b/src/TreeProbability.cpp @@ -89,7 +89,7 @@ bool TreeProbability::splitNodeInternal(size_t nodeID, std::vector& poss } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if ((min_node_size->size() == 1 && num_samples_node <= (*min_node_size)[0]) || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { addToTerminalNodes(nodeID); } @@ -170,9 +170,19 @@ bool TreeProbability::findBestSplit(size_t nodeID, std::vector& possible uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + // TODO: Possible to stop early for class-wise min_bucket? 
// Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() > 1 || (num_samples_node >= 2 * (*min_bucket)[0])) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -287,7 +297,7 @@ void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -321,6 +331,21 @@ void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -379,7 +404,7 @@ void TreeProbability::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -413,6 +438,21 @@ void TreeProbability::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + 
} + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -490,7 +530,7 @@ void TreeProbability::findBestSplitValueUnordered(size_t nodeID, size_t varID, s size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -520,6 +560,21 @@ void TreeProbability::findBestSplitValueUnordered(size_t nodeID, size_t varID, s // Decrease of impurity decrease = sum_left / (double) n_left + sum_right / (double) n_right; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Regularization regularize(decrease, varID); @@ -548,9 +603,19 @@ bool TreeProbability::findBestSplitExtraTrees(size_t nodeID, std::vector uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } + + // Stop if class-wise minimal node size reached + if (min_node_size->size() > 1) { + for (size_t j = 0; j < num_classes; ++j) { + if (class_counts[j] < (*min_node_size)[j]) { + return true; + } + } + } + // TODO: Possible to stop early for class-wise min_bucket? 
// Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (min_bucket->size() > 1 || (num_samples_node >= 2 * (*min_bucket)[0])) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -661,7 +726,7 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0])) { continue; } @@ -675,6 +740,21 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached for split candidate i. NOTE(review): assumes class_counts_right is laid out as [i * num_classes + j], matching n_right[i] - confirm against the counting loop + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[i * num_classes + j]; + if (class_counts[j] - class_count_right < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; @@ -772,7 +852,7 @@ void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_ size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { continue; } @@ -786,6 +866,21 @@ void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_ sum_right += (*class_weights)[j] * class_count_right * class_count_right; sum_left += (*class_weights)[j] * class_count_left * class_count_left; } + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < 
num_classes; ++j) { + size_t class_count_left = class_counts[j] - class_counts_right[j]; + if (class_count_left < (*min_bucket)[j] || class_counts_right[j] < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } // Decrease of impurity double decrease = sum_left / (double) n_left + sum_right / (double) n_right; diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index c272695b..ec59528f 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -68,7 +68,7 @@ bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector& possi } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { split_values[nodeID] = estimate(nodeID); return true; } @@ -150,7 +150,7 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_ } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -261,7 +261,7 @@ void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, doubl } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -323,7 +323,7 @@ void TreeRegression::findBestSplitValueLargeQ(size_t nodeID, size_t varID, doubl } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -405,7 +405,7 @@ void TreeRegression::findBestSplitValueUnordered(size_t nodeID, size_t varID, do size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < 
min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -548,7 +548,7 @@ bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector& } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -658,7 +658,7 @@ void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, d } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0]) { continue; } @@ -758,7 +758,7 @@ void TreeRegression::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t size_t n_left = num_samples_node - n_right; // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { continue; } @@ -793,7 +793,7 @@ bool TreeRegression::findBestSplitBeta(size_t nodeID, std::vector& possi } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables find best split value for (auto& varID : possible_split_varIDs) { @@ -886,7 +886,7 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double } // Stop if minimal bucket size reached - if (n_left < min_bucket || n_right[i] < min_bucket) { + if (n_left < (*min_bucket)[0] || n_right[i] < (*min_bucket)[0]) { continue; } diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index 1c60ba8b..7ae94768 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -140,7 +140,7 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_sp } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if 
(num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeSurvival(nodeID); } @@ -148,7 +148,7 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_sp } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -200,7 +200,7 @@ bool TreeSurvival::findBestSplitMaxstat(size_t nodeID, std::vector& poss size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID]; // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeDeathCounts(nodeID); computeSurvival(nodeID); @@ -414,7 +414,7 @@ void TreeSurvival::findBestSplitValueLogRank(size_t nodeID, size_t varID, double // Stop if minimal bucket size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child[i]; - if (num_samples_right_child[i] < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child[i] < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -520,7 +520,7 @@ void TreeSurvival::findBestSplitValueLogRankUnordered(size_t nodeID, size_t varI // Stop if minimal bucket size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child; - if (num_samples_right_child < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -611,7 +611,7 @@ void TreeSurvival::findBestSplitValueAUC(size_t nodeID, size_t varID, double& be for (size_t i = 0; i < num_splits; ++i) { // Do not consider this split point if 
fewer than min_bucket samples in one node size_t num_samples_right_child = num_node_samples - num_samples_left_child[i]; - if (num_samples_left_child[i] < min_bucket || num_samples_right_child < min_bucket) { + if (num_samples_left_child[i] < (*min_bucket)[0] || num_samples_right_child < (*min_bucket)[0]) { continue; } else { double auc = fabs((num_count[i] / 2) / num_total[i] - 0.5); @@ -711,7 +711,7 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& p } // Stop if maximum node size or depth reached - if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { + if (num_samples_node <= (*min_node_size)[0] || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth)) { if (!save_node_stats) { computeSurvival(nodeID); } @@ -719,7 +719,7 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& p } // Stop early if no split posssible - if (num_samples_node >= 2 * min_bucket) { + if (num_samples_node >= 2 * (*min_bucket)[0]) { // For all possible split variables for (auto& varID : possible_split_varIDs) { @@ -805,7 +805,7 @@ void TreeSurvival::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, dou // Stop if minimal node size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child[i]; - if (num_samples_right_child[i] < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child[i] < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } @@ -934,7 +934,7 @@ void TreeSurvival::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t v // Stop if minimal node size reached size_t num_samples_left_child = num_samples_node - num_samples_right_child; - if (num_samples_right_child < min_bucket || num_samples_left_child < min_bucket) { + if (num_samples_right_child < (*min_bucket)[0] || num_samples_left_child < (*min_bucket)[0]) { continue; } diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp 
index c8c4fed2..66e34c0e 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -50,7 +50,7 @@ using namespace ranger; // [[Rcpp::export]] Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, - bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, + bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, diff --git a/tests/testthat/test_ranger.R b/tests/testthat/test_ranger.R index 0f2c0118..40c788c0 100644 --- a/tests/testthat/test_ranger.R +++ b/tests/testthat/test_ranger.R @@ -86,6 +86,14 @@ test_that("Inbag counts match sample fraction, classification", { expect_equal(unique(colSums(inbag[dat$Species == "setosa", ])), 15) expect_equal(unique(colSums(inbag[dat$Species == "versicolor", ])), 30) expect_equal(unique(colSums(inbag[dat$Species == "virginica", ])), 45) + + ## No factor outcome + rf <- ranger(Species ~ ., data.matrix(iris), num.trees = 5, sample.fraction = c(0.2, 0.3, 0.4), + replace = TRUE, keep.inbag = TRUE, classification = TRUE) + inbag <- do.call(cbind, rf$inbag.counts) + expect_equal(unique(colSums(inbag[iris$Species == "setosa", ])), 30) + expect_equal(unique(colSums(inbag[iris$Species == "versicolor", ])), 45) + expect_equal(unique(colSums(inbag[iris$Species == "virginica", ])), 60) }) test_that("Inbag counts match sample fraction, probability", { @@ -104,6 +112,14 @@ test_that("Inbag counts match sample fraction, probability", { expect_equal(unique(colSums(inbag[1:50, ])), 15) expect_equal(unique(colSums(inbag[51:100, ])), 30) expect_equal(unique(colSums(inbag[101:150, ])), 45) + + ## No factor outcome + rf <- 
ranger(Species ~ ., data.matrix(iris), num.trees = 5, sample.fraction = c(0.2, 0.3, 0.4), + replace = TRUE, keep.inbag = TRUE, probability = TRUE) + inbag <- do.call(cbind, rf$inbag.counts) + expect_equal(unique(colSums(inbag[1:50, ])), 30) + expect_equal(unique(colSums(inbag[51:100, ])), 45) + expect_equal(unique(colSums(inbag[101:150, ])), 60) }) test_that("as.factor() in formula works", { @@ -382,3 +398,85 @@ test_that("min.bucket creates nodes of correct size", { })) expect_gte(smallest_node, min.bucket) }) + +test_that("Vector min.bucket creates nodes of correct size", { + + # Size 2,3,4 + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = c(2, 3, 4), keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(2, 3, 4), ncol = 5, nrow = 3))) + + # Size 4,3,2 + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = c(4, 3, 2), keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(4, 3, 2), ncol = 5, nrow = 3))) + + # Random size + 
min.bucket <- round(runif(3, 1, 10)) + rf <- ranger(Species ~ ., iris, num.trees = 5, replace = FALSE, + min.bucket = min.bucket, keep.inbag = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(min.bucket, ncol = 5, nrow = 3))) + + # No factor outcome + rf <- ranger(Species ~ ., data.matrix(iris), num.trees = 5, replace = FALSE, + min.bucket = c(2, 3, 4), keep.inbag = TRUE, classification = TRUE) + pred <- predict(rf, iris, type = "terminalNodes")$prediction + inbag <- sapply(rf$inbag.counts, function(x) x == 1) + + smallest_nodes <- sapply(1:ncol(pred), function(i) { + pred1 <- pred[which(inbag[, i][1:50]), i] + pred2 <- pred[which(inbag[, i][51:100]) + 50, i] + pred3 <- pred[which(inbag[, i][101:150]) + 100, i] + + pred <- rbind(data.frame(class = 1, node = pred1), + data.frame(class = 2, node = pred2), + data.frame(class = 3, node = pred3)) + apply(table(pred), 1, min) + }) + + expect_true(all(smallest_nodes >= matrix(c(2, 3, 4), ncol = 5, nrow = 3))) +}) + +