Skip to content

Commit

Permalink
Merge pull request #495 from lorentzenchr/poisson
Browse files Browse the repository at this point in the history
Add Poisson splitting rule
  • Loading branch information
mnwright authored Jun 11, 2024
2 parents 682889e + 6e9d42b commit 10b73fd
Show file tree
Hide file tree
Showing 22 changed files with 420 additions and 74 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.1
Version: 0.16.2
Date: 2024-05-16
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.16.2
* Add Poisson splitting rule for regression trees

# ranger 0.16.1
* Set num.threads=2 as default; respect environment variables and options
* Add hierarchical shrinkage
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)
}

numSmaller <- function(values, reference) {
Expand Down
3 changes: 2 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
splitrule <- 1
alpha <- 0
minprop <- 0
poisson.tau <- 1
case.weights <- c(0, 0)
use.case.weights <- FALSE
class.weights <- c(0, 0)
Expand Down Expand Up @@ -276,7 +277,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
prediction.mode, forest, snp.data, replace, probability,
unordered.factor.variables, use.unordered.factor.variables, save.memory, splitrule,
case.weights, use.case.weights, class.weights,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout,
predict.all, keep.inbag, sample.fraction, alpha, minprop, poisson.tau, holdout,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
Expand Down
32 changes: 27 additions & 5 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@
##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values.
##' @param case.weights Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees.
##' @param class.weights Weights for the outcome classes (in order of the factor levels) in the splitting rule (cost sensitive learning). Classification and probability prediction only. For classification the weights are also applied in the majority vote in terminal nodes.
##' @param splitrule Splitting rule. For classification and probability estimation "gini", "extratrees" or "hellinger" with default "gini". For regression "variance", "extratrees", "maxstat" or "beta" with default "variance". For survival "logrank", "extratrees", "C" or "maxstat" with default "logrank".
##' @param splitrule Splitting rule. For classification and probability estimation "gini", "extratrees" or "hellinger" with default "gini".
##' For regression "variance", "extratrees", "maxstat", "beta" or "poisson" with default "variance".
##' For survival "logrank", "extratrees", "C" or "maxstat" with default "logrank".
##' @param num.random.splits For "extratrees" splitrule: Number of random splits to consider for each candidate splitting variable.
##' @param alpha For "maxstat" splitrule: Significance threshold to allow splitting.
##' @param minprop For "maxstat" splitrule: Lower quantile of covariate distribution to be considered for splitting.
##' @param poisson.tau For "poisson" splitrule: The coefficient of variation of the (expected) frequency is \eqn{1/\tau}.
##' If a terminal node has only 0 responses, the estimate is set to \eqn{\alpha 0 + (1-\alpha) mean(parent)} with \eqn{\alpha = samples(child) mean(parent) / (\tau + samples(child) mean(parent))}.
##' @param split.select.weights Numeric vector with weights between 0 and 1, used to calculate the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used.
##' @param always.split.variables Character vector with variable names to be always selected in addition to the \code{mtry} variables tried for splitting.
##' @param respect.unordered.factors Handling of unordered factor covariates. One of 'ignore', 'order' and 'partition'. For the "extratrees" splitrule the default is "partition" for all other splitrules 'ignore'. Alternatively TRUE (='order') or FALSE (='ignore') can be used. See below for details.
Expand Down Expand Up @@ -237,6 +241,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
replace = TRUE, sample.fraction = ifelse(replace, 1, 0.632),
case.weights = NULL, class.weights = NULL, splitrule = NULL,
num.random.splits = 1, alpha = 0.5, minprop = 0.1,
poisson.tau = 1,
split.select.weights = NULL, always.split.variables = NULL,
respect.unordered.factors = NULL,
scale.permutation.importance = FALSE,
Expand Down Expand Up @@ -818,6 +823,17 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
if ((is.factor(y) && nlevels(y) > 2) || (length(unique(y)) > 2)) {
stop("Error: Hellinger splitrule only implemented for binary classification.")
}
} else if (splitrule == "poisson") {
if (treetype == 3) {
splitrule.num <- 8
} else {
stop("Error: poisson splitrule applicable to regression data only.")
}

## Check for valid responses: the poisson splitrule requires that no outcome
## is negative (y >= 0) and that the outcomes do not all sum to zero (sum(y) > 0).
if (min(y) < 0 || sum(y) <= 0) {
  stop("Error: poisson splitrule applicable to regression data with non-negative outcome (y>=0 and sum(y)>0) only.")
}
} else {
stop("Error: Unknown splitrule.")
}
Expand All @@ -843,6 +859,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
if (num.random.splits > 1 && splitrule.num != 5) {
warning("Argument 'num.random.splits' ignored if splitrule is not 'extratrees'.")
}

## poisson.tau must be a single, non-missing, strictly positive number.
## The length and NA checks run first so that the final comparison is always a
## scalar TRUE/FALSE; otherwise NA or vector input would fail inside `if`/`||`
## with an unrelated R error instead of the message below.
if (!is.numeric(poisson.tau) || length(poisson.tau) != 1 ||
    is.na(poisson.tau) || poisson.tau <= 0) {
  stop("Error: Invalid value for poisson.tau, please give a positive number.")
}

## Unordered factors
if (respect.unordered.factors == "partition") {
Expand Down Expand Up @@ -879,6 +899,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Unordered factor splitting not implemented for 'C' splitting rule.")
} else if (splitrule == "beta") {
stop("Error: Unordered factor splitting not implemented for 'beta' splitting rule.")
} else if (splitrule == "poisson") {
stop("Error: Unordered factor splitting not implemented for 'poisson' splitting rule.")
}
}

Expand Down Expand Up @@ -966,10 +988,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
prediction.mode, loaded.forest, snp.data,
replace, probability, unordered.factor.variables, use.unordered.factor.variables,
save.memory, splitrule.num, case.weights, use.case.weights, class.weights,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
predict.all, keep.inbag, sample.fraction, alpha, minprop, poisson.tau,
holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
node.stats, time.interest, use.time.interest)

if (length(result) == 0) {
Expand Down
2 changes: 1 addition & 1 deletion cpp_version/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void run_ranger(const ArgumentHandler& arg_handler, std::ostream& verbose_out) {
arg_handler.predict, arg_handler.impmeasure, arg_handler.targetpartitionsize, arg_handler.minbucket, arg_handler.splitweights,
arg_handler.alwayssplitvars, arg_handler.statusvarname, arg_handler.replace, arg_handler.catvars,
arg_handler.savemem, arg_handler.splitrule, arg_handler.caseweights, arg_handler.predall, arg_handler.fraction,
arg_handler.alpha, arg_handler.minprop, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.alpha, arg_handler.minprop, arg_handler.tau, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.randomsplits, arg_handler.maxdepth, arg_handler.regcoef, arg_handler.usedepth);

forest->run(true, !arg_handler.skipoob);
Expand Down
31 changes: 27 additions & 4 deletions cpp_version/src/utility/ArgumentHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace ranger {

ArgumentHandler::ArgumentHandler(int argc, char **argv) :
caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), tau(DEFAULT_POISSON_TAU), nthreads(
DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth(
DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), minbucket(0), mtry(0), outprefix(
"ranger_out"), probability(false), splitrule(DEFAULT_SPLITRULE), statusvarname(""), ntree(DEFAULT_NUM_TREE), replace(
Expand All @@ -33,7 +33,7 @@ ArgumentHandler::ArgumentHandler(int argc, char **argv) :
int ArgumentHandler::processArguments() {

// short options
char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:";
char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:T:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:";

// long options: longname, no/optional/required argument?, flag(not used!), shortname
const struct option long_options[] = {
Expand All @@ -50,6 +50,7 @@ int ArgumentHandler::processArguments() {
{ "predictiontype", required_argument, 0, 'Q'},
{ "randomsplits", required_argument, 0, 'R'},
{ "splitweights", required_argument, 0, 'S'},
{ "tau", required_argument, 0, 'T'},
{ "nthreads", required_argument, 0, 'U'},
{ "predall", no_argument, 0, 'X'},
{ "version", no_argument, 0, 'Z'},
Expand Down Expand Up @@ -178,6 +179,20 @@ int ArgumentHandler::processArguments() {
case 'S':
splitweights = optarg;
break;

case 'T':
try {
double temp = std::stod(optarg);
if (temp <= 0) {
throw std::runtime_error("");
} else {
tau = temp;
}
} catch (...) {
throw std::runtime_error(
"Illegal argument for option 'tau'. Please give a positive value. See '--help' for details.");
}
break;

case 'U':
try {
Expand Down Expand Up @@ -352,6 +367,9 @@ int ArgumentHandler::processArguments() {
case 7:
splitrule = HELLINGER;
break;
case 8:
splitrule = POISSON;
break;
default:
throw std::runtime_error("");
break;
Expand Down Expand Up @@ -512,7 +530,8 @@ void ArgumentHandler::checkArguments() {
if (((splitrule == AUC || splitrule == AUC_IGNORE_TIES) && treetype != TREE_SURVIVAL)
|| (splitrule == MAXSTAT && (treetype != TREE_SURVIVAL && treetype != TREE_REGRESSION))
|| (splitrule == BETA && treetype != TREE_REGRESSION)
|| (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY)) {
|| (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY)
|| (splitrule == POISSON && treetype != TREE_REGRESSION)) {
throw std::runtime_error("Illegal splitrule selected. See '--help' for details.");
}

Expand Down Expand Up @@ -658,8 +677,9 @@ void ArgumentHandler::displayHelp() {
<< " RULE = 4: MAXSTAT for Survival and Regression, not available for Classification."
<< std::endl;
std::cout << " " << " RULE = 5: ExtraTrees for all tree types." << std::endl;
std::cout << " " << " RULE = 6: BETA for regression, only for (0,1) bounded outcomes." << std::endl;
std::cout << " " << " RULE = 6: BETA for Regression, only for (0,1) bounded outcomes." << std::endl;
std::cout << " " << " RULE = 7: Hellinger for Classification, not available for Regression and Survival." << std::endl;
std::cout << " " << " RULE = 8: Poisson for Regression, not available for Classification and Survival." << std::endl;
std::cout << " " << " (Default: 1)" << std::endl;
std::cout << " "
<< "--randomsplits N Number of random splits to consider for each splitting variable (ExtraTrees splitrule only)."
Expand All @@ -670,6 +690,9 @@ void ArgumentHandler::displayHelp() {
// Help entry for the --minprop option (MAXSTAT splitrule only).
std::cout << " "
<< "--minprop VAL Lower quantile of covariate distribution to be considered for splitting (MAXSTAT splitrule only)."
<< std::endl;
std::cout << " "
<< "--tau VAL Tau parameter for Poisson splitting (Poisson splitrule only)."
<< std::endl;
std::cout << " " << "--caseweights FILE Filename of case weights file." << std::endl;
std::cout << " "
<< "--holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable "
Expand Down
1 change: 1 addition & 0 deletions cpp_version/src/utility/ArgumentHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ArgumentHandler {
PredictionType predictiontype;
uint randomsplits;
std::string splitweights;
double tau;
uint nthreads;
bool predall;

Expand Down
2 changes: 1 addition & 1 deletion cpp_version/src/version.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#ifndef RANGER_VERSION
#define RANGER_VERSION "0.16.1"
#define RANGER_VERSION "0.16.2"
#endif
8 changes: 7 additions & 1 deletion man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 10b73fd

Please sign in to comment.