diff --git a/DESCRIPTION b/DESCRIPTION index 6ca00478..66e395ac 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests -Version: 0.15.3 -Date: 2023-07-19 +Version: 0.15.4 +Date: 2023-11-03 Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] Maintainer: Marvin N. Wright Description: A fast implementation of Random Forests, particularly suited for high diff --git a/NEWS.md b/NEWS.md index 22a8a8d6..e07a5cc4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,7 @@ +# ranger 0.15.4 +* Add time.interest option to restrict unique survival times (faster and saves memory) + # ranger 0.15.3 * Fix min bucket option in C++ version diff --git a/R/RcppExports.R b/R/RcppExports.R index 19cc8e8a..83507c1e 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) { - .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) +rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) { + .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) } numSmaller <- function(values, reference) { diff --git a/R/predict.R b/R/predict.R index ef1397e7..c53c0da2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -250,6 +250,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, regularization.factor <- c(0, 0) use.regularization.factor <- FALSE regularization.usedepth <- FALSE + time.interest <- c(0, 0) + use.time.interest <- FALSE ## Use sparse matrix if (inherits(x, "dgCMatrix")) { @@ -273,7 +275,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, - regularization.factor, use.regularization.factor, regularization.usedepth) + regularization.factor, use.regularization.factor, regularization.usedepth, + time.interest, use.time.interest) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/R/ranger.R b/R/ranger.R index 54a18342..b9be605b 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -115,6 +115,7 @@ ##' @param inbag Manually set observations per tree. List of size num.trees, containing inbag counts for each observation. Can be used for stratified sampling. ##' @param holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable importance and prediction error. ##' @param quantreg Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction. +##' @param time.interest Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a vector of time points or a single number to use as many time points (grid over observed time points). ##' @param oob.error Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests. ##' @param num.threads Number of threads. Default is number of CPUs available. ##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems. @@ -222,7 +223,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, local.importance = FALSE, regularization.factor = 1, regularization.usedepth = FALSE, keep.inbag = FALSE, inbag = NULL, holdout = FALSE, - quantreg = FALSE, oob.error = TRUE, + quantreg = FALSE, time.interest = NULL, oob.error = TRUE, num.threads = NULL, save.memory = FALSE, verbose = TRUE, seed = NULL, dependent.variable.name = NULL, status.variable.name = NULL, @@ -822,6 +823,32 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } } + ## Time of interest + if (is.null(time.interest)) { + time.interest <- c(0, 0) + use.time.interest <- FALSE + } else { + use.time.interest <- TRUE + if (treetype != 5) { + stop("Error: time.interest only applicable to survival forests.") + } + if (is.numeric(time.interest) & length(time.interest) == 1) { + if (time.interest < 1) { + stop("Error: time.interest must be a positive integer.") + } + # Grid over observed time points + nocens <- y[, 2] > 0 + time <- sort(unique(y[nocens, 1])) + if (length(time) <= time.interest) { + time.interest <- time + } else { + time.interest <- time[unique(round(seq.int(1, length(time), length.out = time.interest)))] + } + } else { + time.interest <- sort(unique(time.interest)) + } + } + ## Prediction mode always false. Use predict.ranger() method. prediction.mode <- FALSE predict.all <- FALSE @@ -873,7 +900,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, - regularization.factor, use.regularization.factor, regularization.usedepth) + regularization.factor, use.regularization.factor, regularization.usedepth, + time.interest, use.time.interest) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/man/ranger.Rd b/man/ranger.Rd index 63d6d395..ff202a07 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -34,6 +34,7 @@ ranger( inbag = NULL, holdout = FALSE, quantreg = FALSE, + time.interest = NULL, oob.error = TRUE, num.threads = NULL, save.memory = FALSE, @@ -106,6 +107,8 @@ ranger( \item{quantreg}{Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction.} +\item{time.interest}{Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a vector of time points or a single number to use as many time points (grid over observed time points).} + \item{oob.error}{Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests.} \item{num.threads}{Number of threads. Default is number of CPUs available.} diff --git a/src/ForestSurvival.cpp b/src/ForestSurvival.cpp index 0b20dff8..3b88e312 100644 --- a/src/ForestSurvival.cpp +++ b/src/ForestSurvival.cpp @@ -42,6 +42,43 @@ void ForestSurvival::loadForest(size_t num_trees, std::vector& time_interest) { + + if (time_interest.empty()) { + // Use all observed unique time points + std::set unique_timepoint_set; + for (size_t i = 0; i < num_samples; ++i) { + if (data->get_y(i, 1) > 0) { + unique_timepoint_set.insert(data->get_y(i, 0)); + } + } + unique_timepoints.reserve(unique_timepoint_set.size()); + for (auto& t : unique_timepoint_set) { + unique_timepoints.push_back(t); + } + } else { + // Use the supplied time points of interest + unique_timepoints = time_interest; + } + + // Create response_timepointIDs + for (size_t i = 0; i < num_samples; ++i) { + double value = data->get_y(i, 0); + + // If timepoint is already in unique_timepoints, use ID. Else create a new one. + uint timepointID = 0; + if (value > unique_timepoints[unique_timepoints.size() - 1]) { + timepointID = unique_timepoints.size() - 1; + } else if (value > unique_timepoints[0]) { + timepointID = std::lower_bound(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin(); + } + if (timepointID < 0) { + timepointID = 0; + } + response_timepointIDs.push_back(timepointID); + } +} + std::vector>> ForestSurvival::getChf() const { std::vector>> result; result.reserve(num_trees); @@ -70,27 +107,6 @@ void ForestSurvival::initInternal() { min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL; } - // Create unique timepoints - if (!prediction_mode) { - std::set unique_timepoint_set; - for (size_t i = 0; i < num_samples; ++i) { - unique_timepoint_set.insert(data->get_y(i, 0)); - } - unique_timepoints.reserve(unique_timepoint_set.size()); - for (auto& t : unique_timepoint_set) { - unique_timepoints.push_back(t); - } - - // Create response_timepointIDs - for (size_t i = 0; i < num_samples; ++i) { - double value = data->get_y(i, 0); - - // If timepoint is already in unique_timepoints, use ID. Else create a new one. - uint timepointID = find(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin(); - response_timepointIDs.push_back(timepointID); - } - } - // Sort data if extratrees and not memory saving mode if (splitrule == EXTRATREES && !memory_saving_splitting) { data->sort(); @@ -98,6 +114,13 @@ void ForestSurvival::initInternal() { } void ForestSurvival::growInternal() { + + // If unique time points not set, use observed times + if (unique_timepoints.empty()) { + setUniqueTimepoints(std::vector()); + } + + trees.reserve(num_trees); for (size_t i = 0; i < num_trees; ++i) { trees.push_back(std::make_unique(&unique_timepoints, &response_timepointIDs)); diff --git a/src/ForestSurvival.h b/src/ForestSurvival.h index 15b8a87f..d2efe4a8 100644 --- a/src/ForestSurvival.h +++ b/src/ForestSurvival.h @@ -34,6 +34,8 @@ class ForestSurvival: public Forest { std::vector>& forest_split_varIDs, std::vector>& forest_split_values, std::vector> >& forest_chf, std::vector& unique_timepoints, std::vector& is_ordered_variable); + + void setUniqueTimepoints(const std::vector& time_interest); std::vector>> getChf() const; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 65b57cab..97900517 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,8 +13,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth); -RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP) { +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, std::vector& time_interest, bool use_time_interest); +RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -65,7 +65,9 @@ BEGIN_RCPP Rcpp::traits::input_parameter< std::vector& >::type regularization_factor(regularization_factorSEXP); Rcpp::traits::input_parameter< bool >::type use_regularization_factor(use_regularization_factorSEXP); Rcpp::traits::input_parameter< bool >::type regularization_usedepth(regularization_usedepthSEXP); - rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth)); + Rcpp::traits::input_parameter< std::vector& >::type time_interest(time_interestSEXP); + Rcpp::traits::input_parameter< bool >::type use_time_interest(use_time_interestSEXP); + rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest)); return rcpp_result_gen; END_RCPP } @@ -96,7 +98,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 47}, + {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 49}, {"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2}, {"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3}, {NULL, NULL, 0} diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index e743ca15..a2baee9a 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -61,7 +61,8 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, - std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth) { + std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, + std::vector& time_interest, bool use_time_interest) { Rcpp::List result; @@ -88,6 +89,9 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM if (!use_regularization_factor) { regularization_factor.clear(); } + if (!use_time_interest) { + time_interest.clear(); + } std::ostream* verbose_out; if (verbose) { @@ -191,6 +195,12 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM auto& temp = dynamic_cast(*forest); temp.setClassWeights(class_weights); } + + // Set time points of interest + if (treetype == TREE_SURVIVAL && !time_interest.empty()) { + auto& temp = dynamic_cast(*forest); + temp.setUniqueTimepoints(time_interest); + } } // Run Ranger diff --git a/tests/testthat/test_survival.R b/tests/testthat/test_survival.R index d0fb0aaf..6676d6e0 100644 --- a/tests/testthat/test_survival.R +++ b/tests/testthat/test_survival.R @@ -57,7 +57,7 @@ test_that("predict works for single observations, survival", { ## Special tests for random forests for survival analysis test_that("unique death times in survival result is right", { - expect_equal(rg.surv$unique.death.times, sort(unique(veteran$time))) + expect_equal(rg.surv$unique.death.times, sort(unique(veteran$time[veteran$status > 0]))) }) test_that("C-index splitting works", { @@ -124,3 +124,57 @@ test_that("Survival error for competing risk data", { expect_error(ranger(y = sobj, x = veteran[, 1:2], num.trees = 5), "Error: Competing risks not supported yet\\. Use status=1 for events and status=0 for censoring\\.") }) + +test_that("Right unique time points without time.interest", { + times <- sort(unique(veteran$time[veteran$status > 0])) + + rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5) + expect_equal(timepoints(rf), times) + + rf <- ranger(y = Surv(veteran$time, veteran$status), x = veteran[, c(-3, -4)], num.trees = 5) + expect_equal(timepoints(rf), times) +}) + +test_that("time.interest results in the right number of time points", { + rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, time.interest = 20) + expect_equal(length(timepoints(rf)), 20) + + rf <- ranger(y = Surv(veteran$time, veteran$status), x = veteran[, c(-3, -4)], + num.trees = 5, time.interest = 20) + expect_equal(length(timepoints(rf)), 20) + + rf <- ranger(y = cbind(veteran$time, veteran$status), x = veteran[, c(-3, -4)], + num.trees = 5, time.interest = 20) + expect_equal(length(timepoints(rf)), 20) + + rf <- ranger(dependent.variable.name = "time", status.variable.name = "status", + data = veteran, num.trees = 5, time.interest = 20) + expect_equal(length(timepoints(rf)), 20) +}) + +test_that("time.interest results in the right time points", { + times <- c(20, 100, 200, 1000) + + rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, time.interest = times) + expect_equal(timepoints(rf), times) + + rf <- ranger(y = Surv(veteran$time, veteran$status), x = veteran[, c(-3, -4)], + num.trees = 5, time.interest = times) + expect_equal(timepoints(rf), times) + + rf <- ranger(y = cbind(veteran$time, veteran$status), x = veteran[, c(-3, -4)], + num.trees = 5, time.interest = times) + expect_equal(timepoints(rf), times) + + rf <- ranger(dependent.variable.name = "time", status.variable.name = "status", + data = veteran, num.trees = 5, time.interest = times) + expect_equal(timepoints(rf), times) +}) + +test_that("If more unique time points requested then observed, use observed times", { + times <- sort(unique(veteran$time[veteran$status > 0])) + rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, time.interest = 200) + expect_equal(timepoints(rf), times) +}) + +