Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add time.interest to reduce unique time points #700

Merged
merged 5 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.15.3
Date: 2023-07-19
Version: 0.15.4
Date: 2023-11-03
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.15.4
* Add time.interest option to restrict unique survival times (faster and saves memory)

# ranger 0.15.3
* Fix min bucket option in C++ version

Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest)
}

numSmaller <- function(values, reference) {
Expand Down
5 changes: 4 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
regularization.factor <- c(0, 0)
use.regularization.factor <- FALSE
regularization.usedepth <- FALSE
time.interest <- c(0, 0)
use.time.interest <- FALSE

## Use sparse matrix
if (inherits(x, "dgCMatrix")) {
Expand All @@ -273,7 +275,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
32 changes: 30 additions & 2 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
##' @param inbag Manually set observations per tree. List of size num.trees, containing inbag counts for each observation. Can be used for stratified sampling.
##' @param holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable importance and prediction error.
##' @param quantreg Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction.
##' @param time.interest Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a vector of time points or a single number to use as many time points (grid over observed time points).
##' @param oob.error Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.
Expand Down Expand Up @@ -222,7 +223,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
local.importance = FALSE,
regularization.factor = 1, regularization.usedepth = FALSE,
keep.inbag = FALSE, inbag = NULL, holdout = FALSE,
quantreg = FALSE, oob.error = TRUE,
quantreg = FALSE, time.interest = NULL, oob.error = TRUE,
num.threads = NULL, save.memory = FALSE,
verbose = TRUE, seed = NULL,
dependent.variable.name = NULL, status.variable.name = NULL,
Expand Down Expand Up @@ -822,6 +823,32 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Time of interest
if (is.null(time.interest)) {
time.interest <- c(0, 0)
use.time.interest <- FALSE
} else {
use.time.interest <- TRUE
if (treetype != 5) {
stop("Error: time.interest only applicable to survival forests.")
}
if (is.numeric(time.interest) & length(time.interest) == 1) {
if (time.interest < 1) {
stop("Error: time.interest must be a positive integer.")
}
# Grid over observed time points
nocens <- y[, 2] > 0
time <- sort(unique(y[nocens, 1]))
if (length(time) <= time.interest) {
time.interest <- time
} else {
time.interest <- time[unique(round(seq.int(1, length(time), length.out = time.interest)))]
}
} else {
time.interest <- sort(unique(time.interest))
}
}

## Prediction mode always false. Use predict.ranger() method.
prediction.mode <- FALSE
predict.all <- FALSE
Expand Down Expand Up @@ -873,7 +900,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
3 changes: 3 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 44 additions & 21 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,43 @@ void ForestSurvival::loadForest(size_t num_trees, std::vector<std::vector<std::v
equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
}

void ForestSurvival::setUniqueTimepoints(const std::vector<double>& time_interest) {

if (time_interest.empty()) {
// Use all observed unique time points
std::set<double> unique_timepoint_set;
for (size_t i = 0; i < num_samples; ++i) {
if (data->get_y(i, 1) > 0) {
unique_timepoint_set.insert(data->get_y(i, 0));
}
}
unique_timepoints.reserve(unique_timepoint_set.size());
for (auto& t : unique_timepoint_set) {
unique_timepoints.push_back(t);
}
} else {
// Use the supplied time points of interest
unique_timepoints = time_interest;
}

// Create response_timepointIDs
for (size_t i = 0; i < num_samples; ++i) {
double value = data->get_y(i, 0);

// If timepoint is already in unique_timepoints, use ID. Else create a new one.
uint timepointID = 0;
if (value > unique_timepoints[unique_timepoints.size() - 1]) {
timepointID = unique_timepoints.size() - 1;
} else if (value > unique_timepoints[0]) {
timepointID = std::lower_bound(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin();
}
if (timepointID < 0) {
timepointID = 0;
}
response_timepointIDs.push_back(timepointID);
}
}

std::vector<std::vector<std::vector<double>>> ForestSurvival::getChf() const {
std::vector<std::vector<std::vector<double>>> result;
result.reserve(num_trees);
Expand Down Expand Up @@ -70,34 +107,20 @@ void ForestSurvival::initInternal() {
min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL;
}

// Create unique timepoints
if (!prediction_mode) {
std::set<double> unique_timepoint_set;
for (size_t i = 0; i < num_samples; ++i) {
unique_timepoint_set.insert(data->get_y(i, 0));
}
unique_timepoints.reserve(unique_timepoint_set.size());
for (auto& t : unique_timepoint_set) {
unique_timepoints.push_back(t);
}

// Create response_timepointIDs
for (size_t i = 0; i < num_samples; ++i) {
double value = data->get_y(i, 0);

// If timepoint is already in unique_timepoints, use ID. Else create a new one.
uint timepointID = find(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin();
response_timepointIDs.push_back(timepointID);
}
}

// Sort data if extratrees and not memory saving mode
if (splitrule == EXTRATREES && !memory_saving_splitting) {
data->sort();
}
}

void ForestSurvival::growInternal() {

// If unique time points not set, use observed times
if (unique_timepoints.empty()) {
setUniqueTimepoints(std::vector<double>());
}


trees.reserve(num_trees);
for (size_t i = 0; i < num_trees; ++i) {
trees.push_back(std::make_unique<TreeSurvival>(&unique_timepoints, &response_timepointIDs));
Expand Down
2 changes: 2 additions & 0 deletions src/ForestSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ForestSurvival: public Forest {
std::vector<std::vector<size_t>>& forest_split_varIDs, std::vector<std::vector<double>>& forest_split_values,
std::vector<std::vector<std::vector<double>> >& forest_chf, std::vector<double>& unique_timepoints,
std::vector<bool>& is_ordered_variable);

void setUniqueTimepoints(const std::vector<double>& time_interest);

std::vector<std::vector<std::vector<double>>> getChf() const;

Expand Down
10 changes: 6 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// rangerCpp
Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, std::vector<double>& class_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix<double>& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector<std::vector<size_t>>& inbag, bool use_inbag, std::vector<double>& regularization_factor, bool use_regularization_factor, bool regularization_usedepth);
RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP) {
Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, std::vector<double>& class_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix<double>& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector<std::vector<size_t>>& inbag, bool use_inbag, std::vector<double>& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, std::vector<double>& time_interest, bool use_time_interest);
RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand Down Expand Up @@ -65,7 +65,9 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< std::vector<double>& >::type regularization_factor(regularization_factorSEXP);
Rcpp::traits::input_parameter< bool >::type use_regularization_factor(use_regularization_factorSEXP);
Rcpp::traits::input_parameter< bool >::type regularization_usedepth(regularization_usedepthSEXP);
rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth));
Rcpp::traits::input_parameter< std::vector<double>& >::type time_interest(time_interestSEXP);
Rcpp::traits::input_parameter< bool >::type use_time_interest(use_time_interestSEXP);
rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -96,7 +98,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 47},
{"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 49},
{"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2},
{"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3},
{NULL, NULL, 0}
Expand Down
Loading
Loading