Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft linear model residual splitting #691

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.15.4
Date: 2023-11-07
Version: 0.16.1
Date: 2023-11-20
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand All @@ -13,7 +13,7 @@ Description: A fast implementation of Random Forests, particularly suited for hi
can be directly analyzed.
License: GPL-3
Imports: Rcpp (>= 0.11.2), Matrix
LinkingTo: Rcpp, RcppEigen
LinkingTo: Rcpp, RcppArmadillo
Depends: R (>= 3.1)
Suggests:
survival,
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@

# ranger 0.16.1
* Add linear model residual splitting

# ranger 0.16.0
* New CRAN version

# ranger 0.15.4
* Add node.stats option to save node statistics of all nodes
* Add time.interest option to restrict unique survival times (faster and saves memory)
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, confounders, use_confounders) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, confounders, use_confounders)
}

numSmaller <- function(values, reference) {
Expand Down
32 changes: 30 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
num.trees = object$num.trees,
type = "response", se.method = "infjack",
seed = NULL, num.threads = NULL,
verbose = TRUE, inbag.counts = NULL, ...) {
verbose = TRUE, inbag.counts = NULL,
confounders = NULL, ...) {

## GenABEL GWA data
if (inherits(data, "gwaa.data")) {
Expand Down Expand Up @@ -122,6 +123,31 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
stop("Error: Invalid value for 'type'. Use 'response', 'se', 'terminalNodes', or 'quantiles'.")
}

if (!is.null(confounders)) {
if (is.null(forest$glm.coefs) || length(unlist(forest$glm.coefs)) == 0) {
stop("For glm prediction, fit a regression RF with the confounders argument.")
}
if (is.data.frame(confounders)) {
confounders <- model.matrix( ~ ., confounders)
} else if (is.matrix(confounders)) {
confounders <- cbind(1, confounders)
} else {
stop("Error: confounders argument has to be matrix or data.frame.")
}
nodes <- predict(object = object, data = data, predict.all = predict.all,
num.trees = num.trees, type = "terminalNodes", se.method = se.method,
seed = seed, num.threads = num.threads,
verbose = verbose, inbag.counts = inbag.counts, ...)$predictions

pred <- sapply(1:num.trees, function(i) {
tree_coefs <- forest$glm.coefs[[i]][nodes[, i] + 1]
sapply(1:length(tree_coefs), function(j) {
confounders[j, ] %*% tree_coefs[[j]]
})
})
return(pred)
}

## Type "se" only for certain tree types
if (type == "se" && se.method == "jack" && forest$treetype != "Regression") {
stop("Error: Jackknife standard error prediction currently only available for regression.")
Expand Down Expand Up @@ -250,6 +276,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
regularization.factor <- c(0, 0)
use.regularization.factor <- FALSE
regularization.usedepth <- FALSE
confounders <- matrix(c(0, 0))
use.confounders <- FALSE
node.stats <- FALSE
time.interest <- c(0, 0)
use.time.interest <- FALSE
Expand Down Expand Up @@ -277,7 +305,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
node.stats, time.interest, use.time.interest)
node.stats, time.interest, use.time.interest, confounders, use.confounders)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
18 changes: 17 additions & 1 deletion R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
##' @param minprop For "maxstat" splitrule: Lower quantile of covariate distribution to be considered for splitting.
##' @param split.select.weights Numeric vector with weights between 0 and 1, used to calculate the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used.
##' @param always.split.variables Character vector with variable names to be always selected in addition to the \code{mtry} variables tried for splitting.
##' @param confounders Confounders data.frame to adjust for in regression RF.
##' @param respect.unordered.factors Handling of unordered factor covariates. One of 'ignore', 'order' and 'partition'. For the "extratrees" splitrule the default is "partition" for all other splitrules 'ignore'. Alternatively TRUE (='order') or FALSE (='ignore') can be used. See below for details.
##' @param scale.permutation.importance Scale permutation importance by standard error as in (Breiman 2001). Only applicable if permutation variable importance mode selected.
##' @param regularization.factor Regularization factor (gain penalization), either a vector of length p or one value for all variables.
Expand Down Expand Up @@ -238,6 +239,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
case.weights = NULL, class.weights = NULL, splitrule = NULL,
num.random.splits = 1, alpha = 0.5, minprop = 0.1,
split.select.weights = NULL, always.split.variables = NULL,
confounders = NULL,
respect.unordered.factors = NULL,
scale.permutation.importance = FALSE,
local.importance = FALSE,
Expand Down Expand Up @@ -850,6 +852,20 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Confounders
if (is.null(confounders)) {
confounders <- matrix(c(0, 0))
use.confounders <- FALSE
} else if (is.data.frame(confounders)) {
confounders <- model.matrix( ~ ., confounders)
use.confounders <- TRUE
} else if (is.matrix(confounders)) {
confounders <- cbind(1, confounders)
use.confounders <- TRUE
} else {
stop("Error: confounders argument has to be matrix or data.frame.")
}

## Time of interest
if (is.null(time.interest)) {
time.interest <- c(0, 0)
Expand Down Expand Up @@ -928,7 +944,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
node.stats, time.interest, use.time.interest)
node.stats, time.interest, use.time.interest, confounders, use.confounders)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
1 change: 1 addition & 0 deletions man/predict.ranger.forest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 19 additions & 1 deletion src/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Data {

virtual double get_x(size_t row, size_t col) const = 0;
virtual double get_y(size_t row, size_t col) const = 0;

size_t getVariableID(const std::string& variable_name) const;

virtual void reserveMemory(size_t y_cols) = 0;
Expand Down Expand Up @@ -196,6 +196,24 @@ class Data {
order_snps = true;
}
// #nocov end

virtual void lm(std::vector<size_t>& sampleIDs, size_t start, size_t end) {
// Empty on purpose
}

virtual std::vector<double> lm_coefs(std::vector<size_t>& sampleIDs, size_t start, size_t end) {
// Empty on purpose
return std::vector<double>();
}

virtual double predict(size_t row, std::vector<double> coefs) {
// Empty on purpose
return 0;
}

virtual double get_yy(size_t row, size_t col) const {
return get_y(row, col);
}

protected:
std::vector<std::string> variable_names;
Expand Down
55 changes: 52 additions & 3 deletions src/DataRcpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ Ratzeburger Allee 160

#ifndef DATARCPP_H_
#define DATARCPP_H_

#include <Rcpp.h>

#include <RcppArmadillo.h>

#include "globals.h"
#include "utility.h"
#include "Data.h"
Expand All @@ -39,13 +40,16 @@ namespace ranger {
class DataRcpp: public Data {
public:
DataRcpp() = default;
DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows, size_t num_cols) {
DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows, size_t num_cols,
Rcpp::NumericMatrix& confounders) {
this->x = x;
this->y = y;
this->variable_names = variable_names;
this->num_rows = num_rows;
this->num_cols = num_cols;
this->num_cols_no_snp = num_cols;
this->confounders = confounders;
this->resid = arma::colvec(y(Rcpp::_, 0));
}

DataRcpp(const DataRcpp&) = delete;
Expand Down Expand Up @@ -86,9 +90,54 @@ class DataRcpp: public Data {
}
// #nocov end

void lm(std::vector<size_t>& sampleIDs, size_t start, size_t end) override {
if (confounders.size() > 0) {
std::vector<size_t> idx;
idx.assign(sampleIDs.begin() + start, sampleIDs.begin() + end);

arma::uvec ia = arma::conv_to<arma::uvec>::from(idx);

arma::mat ca = arma::mat(confounders.begin(), confounders.nrow(),
confounders.ncol(), false);
arma::colvec ya = arma::colvec(y(Rcpp::_, 0));

arma::colvec coef = arma::solve(ca.rows(ia), ya(ia));
resid(ia) = ya(ia) - ca.rows(ia)*coef;
}
}

std::vector<double> lm_coefs(std::vector<size_t>& sampleIDs, size_t start, size_t end) override {
if (confounders.size() > 0) {
std::vector<size_t> idx;
idx.assign(sampleIDs.begin() + start, sampleIDs.begin() + end);

arma::uvec ia = arma::conv_to<arma::uvec>::from(idx);

arma::mat ca = arma::mat(confounders.begin(), confounders.nrow(),
confounders.ncol(), false);
arma::colvec ya = arma::colvec(y(Rcpp::_, 0));

arma::colvec coef = arma::solve(ca.rows(ia), ya(ia));

return arma::conv_to<std::vector<double>>::from(coef);
} else {
return std::vector<double>();
}
}

double predict(size_t row, std::vector<double> coefs) override {
return arma::dot(arma::vec(confounders(row, Rcpp::_)), arma::vec(coefs));
}

double get_yy(size_t row, size_t col) const override {
return resid(row);
}

private:
Rcpp::NumericMatrix x;
Rcpp::NumericMatrix y;
arma::colvec resid;
Rcpp::NumericMatrix confounders;
};

} // namespace ranger
Expand Down
4 changes: 2 additions & 2 deletions src/DataSparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@

namespace ranger {

DataSparse::DataSparse(Eigen::SparseMatrix<double>& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
DataSparse::DataSparse(arma::sp_mat& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
size_t num_cols) :
x { }{
this->x.swap(x);
this->x = x;
this->y = y;
this->variable_names = variable_names;
this->num_rows = num_rows;
Expand Down
10 changes: 5 additions & 5 deletions src/DataSparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#ifndef DATASPARSE_H_
#define DATASPARSE_H_

#include <RcppEigen.h>
#include <RcppArmadillo.h>

#include "globals.h"
#include "utility.h"
Expand All @@ -40,7 +40,7 @@ class DataSparse: public Data {
public:
DataSparse() = default;

DataSparse(Eigen::SparseMatrix<double>& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
DataSparse(arma::sp_mat& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
size_t num_cols);

DataSparse(const DataSparse&) = delete;
Expand All @@ -54,7 +54,7 @@ class DataSparse: public Data {
col = getUnpermutedVarID(col);
row = getPermutedSampleID(row);
}
return x.coeff(row, col);
return x(row, col);
}

double get_y(size_t row, size_t col) const override {
Expand All @@ -67,7 +67,7 @@ class DataSparse: public Data {
}

void set_x(size_t col, size_t row, double value, bool& error) override {
x.coeffRef(row, col) = value;
x(row, col) = value;
}

void set_y(size_t col, size_t row, double value, bool& error) override {
Expand All @@ -76,7 +76,7 @@ class DataSparse: public Data {
// #nocov end

private:
Eigen::SparseMatrix<double> x;
arma::sp_mat x;
Rcpp::NumericMatrix y;
};

Expand Down
10 changes: 10 additions & 0 deletions src/ForestRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ void ForestRegression::loadForest(size_t num_trees,
equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
}

std::vector<std::vector<std::vector<double>>> ForestRegression::getGlmCoefs() const {
std::vector<std::vector<std::vector<double>>> result;
result.reserve(num_trees);
for (const auto& tree : trees) {
const auto& temp = dynamic_cast<const TreeRegression&>(*tree);
result.push_back(temp.getGlmCoefs());
}
return result;
}

void ForestRegression::initInternal() {

// If mtry not set, use floored square root of number of independent variables
Expand Down
2 changes: 2 additions & 0 deletions src/ForestRegression.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class ForestRegression: public Forest {
void loadForest(size_t num_trees, std::vector<std::vector<std::vector<size_t>> >& forest_child_nodeIDs,
std::vector<std::vector<size_t>>& forest_split_varIDs, std::vector<std::vector<double>>& forest_split_values,
std::vector<bool>& is_ordered_variable);

std::vector<std::vector<std::vector<double>>> getGlmCoefs() const;

private:
void initInternal() override;
Expand Down
3 changes: 2 additions & 1 deletion src/Makevars
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
PKG_CPPFLAGS = -DR_BUILD
PKG_CPPFLAGS = -DR_BUILD $(SHLIB_OPENMP_CXXFLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)

4 changes: 3 additions & 1 deletion src/Makevars.win
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
PKG_CPPFLAGS = -DR_BUILD -DWIN_R_BUILD
PKG_CPPFLAGS = -DR_BUILD -DWIN_R_BUILD $(SHLIB_OPENMP_CXXFLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)


Loading