Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added depvar to result #698

Merged
merged 11 commits into from
Nov 8, 2023
Merged
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.15.3
Date: 2023-07-19
Version: 0.15.4
Date: 2023-11-03
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.15.4
* Add time.interest option to restrict unique survival times (faster and saves memory)

# ranger 0.15.3
* Fix min bucket option in C++ version

Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest)
}

numSmaller <- function(values, reference) {
Expand Down
5 changes: 4 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
regularization.factor <- c(0, 0)
use.regularization.factor <- FALSE
regularization.usedepth <- FALSE
time.interest <- c(0, 0)
use.time.interest <- FALSE

## Use sparse matrix
if (inherits(x, "dgCMatrix")) {
Expand All @@ -273,7 +275,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
43 changes: 41 additions & 2 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
##' @param inbag Manually set observations per tree. List of size num.trees, containing inbag counts for each observation. Can be used for stratified sampling.
##' @param holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable importance and prediction error.
##' @param quantreg Prepare quantile prediction as in quantile regression forests (Meinshausen 2006). Regression only. Set \code{keep.inbag = TRUE} to prepare out-of-bag quantile prediction.
##' @param time.interest Time points of interest (survival only). Can be \code{NULL} (default, use all observed time points), a vector of time points, or a single number giving how many time points to use (selected as a grid over the observed time points).
##' @param oob.error Compute OOB prediction error. Set to \code{FALSE} to save computation time, e.g. for large survival forests.
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.
Expand Down Expand Up @@ -146,6 +147,8 @@
##' \item{\code{importance.mode}}{Importance mode used.}
##' \item{\code{num.samples}}{Number of samples.}
##' \item{\code{inbag.counts}}{Number of times the observations are in-bag in the trees.}
##' \item{\code{dependent.variable.name}}{Name of the dependent variable. This is NULL when x/y interface is used.}
##' \item{\code{status.variable.name}}{Name of the status variable (survival only). This is NULL when x/y interface is used.}
##' @examples
##' ## Classification forest with default settings
##' ranger(Species ~ ., data = iris)
Expand Down Expand Up @@ -222,7 +225,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
local.importance = FALSE,
regularization.factor = 1, regularization.usedepth = FALSE,
keep.inbag = FALSE, inbag = NULL, holdout = FALSE,
quantreg = FALSE, oob.error = TRUE,
quantreg = FALSE, time.interest = NULL, oob.error = TRUE,
num.threads = NULL, save.memory = FALSE,
verbose = TRUE, seed = NULL,
dependent.variable.name = NULL, status.variable.name = NULL,
Expand Down Expand Up @@ -276,6 +279,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Invalid formula.")
}
data.selected <- parse.formula(formula, data, env = parent.frame())
dependent.variable.name <- all.vars(formula)[1]
if (survival::is.Surv(data.selected[, 1])) {
status.variable.name <- all.vars(formula)[2]
}
y <- data.selected[, 1]
x <- data.selected[, -1, drop = FALSE]
}
Expand Down Expand Up @@ -822,6 +829,32 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Time of interest
if (is.null(time.interest)) {
time.interest <- c(0, 0)
use.time.interest <- FALSE
} else {
use.time.interest <- TRUE
if (treetype != 5) {
stop("Error: time.interest only applicable to survival forests.")
}
if (is.numeric(time.interest) & length(time.interest) == 1) {
if (time.interest < 1) {
stop("Error: time.interest must be a positive integer.")
}
# Grid over observed time points
nocens <- y[, 2] > 0
time <- sort(unique(y[nocens, 1]))
if (length(time) <= time.interest) {
time.interest <- time
} else {
time.interest <- time[unique(round(seq.int(1, length(time), length.out = time.interest)))]
}
} else {
time.interest <- sort(unique(time.interest))
}
}

## Prediction mode always false. Use predict.ranger() method.
prediction.mode <- FALSE
predict.all <- FALSE
Expand Down Expand Up @@ -873,7 +906,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth)
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down Expand Up @@ -974,6 +1008,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Dependent (and status) variable name
## will be NULL only when x/y interface is used
result$dependent.variable.name <- dependent.variable.name
result$status.variable.name <- status.variable.name

class(result) <- "ranger"

## Prepare quantile prediction
Expand Down
5 changes: 5 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 44 additions & 21 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,43 @@ void ForestSurvival::loadForest(size_t num_trees, std::vector<std::vector<std::v
equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
}

/**
 * Set the grid of time points used by the survival forest and map every
 * sample's observed time onto an index into that grid.
 *
 * Both unique_timepoints and response_timepointIDs are rebuilt from scratch,
 * so calling this function repeatedly does not accumulate stale entries.
 *
 * @param time_interest Time points of interest. If empty, the unique observed
 *        times of all samples whose second response column is > 0 are used
 *        (collected through a std::set, hence sorted ascending). A non-empty
 *        vector is copied as is; it is searched with std::lower_bound, so it
 *        is presumably expected to be sorted ascending -- verify against the
 *        caller.
 */
void ForestSurvival::setUniqueTimepoints(const std::vector<double>& time_interest) {

  if (time_interest.empty()) {
    // Use all observed unique time points
    std::set<double> unique_timepoint_set;
    for (size_t i = 0; i < num_samples; ++i) {
      if (data->get_y(i, 1) > 0) {
        unique_timepoint_set.insert(data->get_y(i, 0));
      }
    }
    // assign() replaces any previous content instead of appending to it
    unique_timepoints.assign(unique_timepoint_set.begin(), unique_timepoint_set.end());
  } else {
    // Use the supplied time points of interest
    unique_timepoints = time_interest;
  }

  // Create response_timepointIDs: one grid index per sample
  response_timepointIDs.clear();
  response_timepointIDs.reserve(num_samples);
  for (size_t i = 0; i < num_samples; ++i) {
    const double value = data->get_y(i, 0);

    // Times at or before the first grid point map to index 0, times after the
    // last grid point map to the last index, everything in between maps to
    // the first grid point that is not smaller than the observed time.
    size_t timepointID = 0;

    // Guard against an empty grid (e.g. no sample with status > 0): indexing
    // the last element of an empty vector would be undefined behaviour.
    if (!unique_timepoints.empty()) {
      if (value > unique_timepoints.back()) {
        timepointID = unique_timepoints.size() - 1;
      } else if (value > unique_timepoints.front()) {
        timepointID = std::lower_bound(unique_timepoints.begin(), unique_timepoints.end(), value)
            - unique_timepoints.begin();
      }
    }
    response_timepointIDs.push_back(timepointID);
  }
}

std::vector<std::vector<std::vector<double>>> ForestSurvival::getChf() const {
std::vector<std::vector<std::vector<double>>> result;
result.reserve(num_trees);
Expand Down Expand Up @@ -70,34 +107,20 @@ void ForestSurvival::initInternal() {
min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL;
}

// Create unique timepoints
if (!prediction_mode) {
std::set<double> unique_timepoint_set;
for (size_t i = 0; i < num_samples; ++i) {
unique_timepoint_set.insert(data->get_y(i, 0));
}
unique_timepoints.reserve(unique_timepoint_set.size());
for (auto& t : unique_timepoint_set) {
unique_timepoints.push_back(t);
}

// Create response_timepointIDs
for (size_t i = 0; i < num_samples; ++i) {
double value = data->get_y(i, 0);

// If timepoint is already in unique_timepoints, use ID. Else create a new one.
uint timepointID = find(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin();
response_timepointIDs.push_back(timepointID);
}
}

// Sort data if extratrees and not memory saving mode
if (splitrule == EXTRATREES && !memory_saving_splitting) {
data->sort();
}
}

void ForestSurvival::growInternal() {

// If unique time points not set, use observed times
if (unique_timepoints.empty()) {
setUniqueTimepoints(std::vector<double>());
}


trees.reserve(num_trees);
for (size_t i = 0; i < num_trees; ++i) {
trees.push_back(std::make_unique<TreeSurvival>(&unique_timepoints, &response_timepointIDs));
Expand Down
2 changes: 2 additions & 0 deletions src/ForestSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ForestSurvival: public Forest {
std::vector<std::vector<size_t>>& forest_split_varIDs, std::vector<std::vector<double>>& forest_split_values,
std::vector<std::vector<std::vector<double>> >& forest_chf, std::vector<double>& unique_timepoints,
std::vector<bool>& is_ordered_variable);

void setUniqueTimepoints(const std::vector<double>& time_interest);

std::vector<std::vector<std::vector<double>>> getChf() const;

Expand Down
Loading
Loading