Skip to content

Commit

Permalink
add time.interest to reduce unique time points
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Nov 3, 2023
1 parent 5a04d93 commit 9093780
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 33 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.15.3
Date: 2023-07-19
Version: 0.15.4
Date: 2023-11-03
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.15.4
* Add time.interest option to restrict unique survival times (faster and saves memory)

# ranger 0.15.3
* Fix min bucket option in C++ version

Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest)
}

numSmaller <- function(values, reference) {
Expand Down
5 changes: 4 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
regularization.factor <- c(0, 0)
use.regularization.factor <- FALSE
regularization.usedepth <- FALSE
time.interest <- c(0, 0)
use.time.interest <- FALSE

## Use sparse matrix
if (inherits(x, "dgCMatrix")) {
Expand All @@ -273,7 +275,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
27 changes: 25 additions & 2 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
##' @param inbag Manually set observations per tree. List of size num.trees, containing inbag counts for each observation. Can be used for stratified sampling.
##' @param holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable importance and prediction error.
##' @param quantreg Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction.
##' @param time.interest Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a numeric vector of time points, or a single number \code{k} to use a grid of at most \code{k} of the observed time points. Note that a single number is always interpreted as the grid size, never as a time point.
##' @param oob.error Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.
Expand Down Expand Up @@ -222,7 +223,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
local.importance = FALSE,
regularization.factor = 1, regularization.usedepth = FALSE,
keep.inbag = FALSE, inbag = NULL, holdout = FALSE,
quantreg = FALSE, oob.error = TRUE,
quantreg = FALSE, time.interest = NULL, oob.error = TRUE,
num.threads = NULL, save.memory = FALSE,
verbose = TRUE, seed = NULL,
dependent.variable.name = NULL, status.variable.name = NULL,
Expand Down Expand Up @@ -822,6 +823,27 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Time points of interest (survival forests only).
## Sets the two values that are handed on to rangerCpp(): 'time.interest'
## (numeric vector) and 'use.time.interest' (logical flag).
if (is.null(time.interest)) {
  ## Option not used: pass a placeholder vector together with a FALSE flag
  time.interest <- c(0, 0)
  use.time.interest <- FALSE
} else {
  use.time.interest <- TRUE
  if (treetype != 5) {
    stop("Error: time.interest only applicable to survival forests.")
  }
  ## Reject non-numeric input here with a clear message instead of letting it
  ## fail later during the conversion to a numeric vector in the C++ call.
  if (!is.numeric(time.interest)) {
    stop("Error: time.interest must be numeric.")
  }
  if (length(time.interest) == 1) {
    ## A single number is the requested number of grid points.
    ## '||' (scalar, short-circuit) keeps an NA from reaching the comparison.
    if (is.na(time.interest) || time.interest < 1) {
      stop("Error: time.interest must be a positive integer.")
    }
    ## Grid over observed time points: equally spaced positions in the sorted
    ## unique times, rounded to indices; unique() drops repeated indices, so
    ## fewer points than requested may be returned.
    time <- sort(unique(y[, 1]))
    time.interest <- time[unique(round(seq.int(1, length(time), length.out = time.interest)))]
  } else {
    ## Explicit time points: sorted and de-duplicated (sort() also drops NAs)
    time.interest <- sort(unique(time.interest))
  }
}

## Prediction mode always false. Use predict.ranger() method.
prediction.mode <- FALSE
predict.all <- FALSE
Expand Down Expand Up @@ -873,7 +895,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
3 changes: 3 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 40 additions & 21 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,39 @@ void ForestSurvival::loadForest(size_t num_trees, std::vector<std::vector<std::v
equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
}

/**
 * Set the time grid of the forest and map every observation onto it.
 *
 * Fills the members unique_timepoints and response_timepointIDs; both are
 * rebuilt from scratch, so calling this more than once does not accumulate
 * stale entries.
 *
 * @param time_interest Time points to use as grid. If empty, all observed
 *        unique times (response column 0) are used. If non-empty, the vector
 *        is copied unchanged and must be sorted in ascending order, because
 *        the ID lookup below uses std::upper_bound.
 */
void ForestSurvival::setUniqueTimepoints(const std::vector<double>& time_interest) {

  if (time_interest.empty()) {
    // Use all observed unique time points; std::set delivers them
    // de-duplicated and in ascending order
    std::set<double> unique_timepoint_set;
    for (size_t i = 0; i < num_samples; ++i) {
      unique_timepoint_set.insert(data->get_y(i, 0));
    }
    unique_timepoints.assign(unique_timepoint_set.begin(), unique_timepoint_set.end());
  } else {
    // Use the supplied time points of interest
    unique_timepoints = time_interest;
  }

  // Create response_timepointIDs: one grid index per observation
  response_timepointIDs.clear();
  response_timepointIDs.reserve(num_samples);
  for (size_t i = 0; i < num_samples; ++i) {
    double value = data->get_y(i, 0);

    // ID of the last grid point that is <= value. Observations at or before
    // the first grid point get ID 0. Inside the branch upper_bound() returns
    // at least begin() + 1, so the "- 1" cannot underflow and no further
    // clamping is needed.
    size_t timepointID = 0;
    if (value > unique_timepoints[0]) {
      timepointID = std::upper_bound(unique_timepoints.begin(), unique_timepoints.end(), value) - 1
          - unique_timepoints.begin();
    }
    response_timepointIDs.push_back(timepointID);
  }
}

std::vector<std::vector<std::vector<double>>> ForestSurvival::getChf() const {
std::vector<std::vector<std::vector<double>>> result;
result.reserve(num_trees);
Expand Down Expand Up @@ -70,34 +103,20 @@ void ForestSurvival::initInternal() {
min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL;
}

// Create unique timepoints
if (!prediction_mode) {
std::set<double> unique_timepoint_set;
for (size_t i = 0; i < num_samples; ++i) {
unique_timepoint_set.insert(data->get_y(i, 0));
}
unique_timepoints.reserve(unique_timepoint_set.size());
for (auto& t : unique_timepoint_set) {
unique_timepoints.push_back(t);
}

// Create response_timepointIDs
for (size_t i = 0; i < num_samples; ++i) {
double value = data->get_y(i, 0);

// If timepoint is already in unique_timepoints, use ID. Else create a new one.
uint timepointID = find(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin();
response_timepointIDs.push_back(timepointID);
}
}

// Sort data if extratrees and not memory saving mode
if (splitrule == EXTRATREES && !memory_saving_splitting) {
data->sort();
}
}

void ForestSurvival::growInternal() {

// If unique time points not set, use observed times
if (unique_timepoints.empty()) {
setUniqueTimepoints(std::vector<double>());
}


trees.reserve(num_trees);
for (size_t i = 0; i < num_trees; ++i) {
trees.push_back(std::make_unique<TreeSurvival>(&unique_timepoints, &response_timepointIDs));
Expand Down
2 changes: 2 additions & 0 deletions src/ForestSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ForestSurvival: public Forest {
std::vector<std::vector<size_t>>& forest_split_varIDs, std::vector<std::vector<double>>& forest_split_values,
std::vector<std::vector<std::vector<double>> >& forest_chf, std::vector<double>& unique_timepoints,
std::vector<bool>& is_ordered_variable);

void setUniqueTimepoints(const std::vector<double>& time_interest);

std::vector<std::vector<std::vector<double>>> getChf() const;

Expand Down
10 changes: 6 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// rangerCpp
Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, std::vector<double>& class_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix<double>& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector<std::vector<size_t>>& inbag, bool use_inbag, std::vector<double>& regularization_factor, bool use_regularization_factor, bool regularization_usedepth);
RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP) {
Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, std::vector<double>& class_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix<double>& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector<std::vector<size_t>>& inbag, bool use_inbag, std::vector<double>& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, std::vector<double>& time_interest, bool use_time_interest);
RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand Down Expand Up @@ -65,7 +65,9 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< std::vector<double>& >::type regularization_factor(regularization_factorSEXP);
Rcpp::traits::input_parameter< bool >::type use_regularization_factor(use_regularization_factorSEXP);
Rcpp::traits::input_parameter< bool >::type regularization_usedepth(regularization_usedepthSEXP);
rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth));
Rcpp::traits::input_parameter< std::vector<double>& >::type time_interest(time_interestSEXP);
Rcpp::traits::input_parameter< bool >::type use_time_interest(use_time_interestSEXP);
rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -96,7 +98,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 47},
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 49},
{"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2},
{"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3},
{NULL, NULL, 0}
Expand Down
Loading

0 comments on commit 9093780

Please sign in to comment.