From a3595b1e66c2e4eebeb6ecc040ae42e6c93925d2 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Sun, 24 Sep 2023 21:38:06 +0200 Subject: [PATCH 1/5] change Eigen to Armadillo --- DESCRIPTION | 2 +- src/DataSparse.cpp | 4 ++-- src/DataSparse.h | 10 +++++----- src/RcppExports.cpp | 6 +++--- src/rangerCpp.cpp | 10 +++++----- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6ca004781..7f7d3a458 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -13,7 +13,7 @@ Description: A fast implementation of Random Forests, particularly suited for hi can be directly analyzed. License: GPL-3 Imports: Rcpp (>= 0.11.2), Matrix -LinkingTo: Rcpp, RcppEigen +LinkingTo: Rcpp, RcppArmadillo Depends: R (>= 3.1) Suggests: survival, diff --git a/src/DataSparse.cpp b/src/DataSparse.cpp index 779a54d6b..59e599ea7 100644 --- a/src/DataSparse.cpp +++ b/src/DataSparse.cpp @@ -30,10 +30,10 @@ namespace ranger { -DataSparse::DataSparse(Eigen::SparseMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, +DataSparse::DataSparse(arma::sp_mat& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols) : x { }{ - this->x.swap(x); + this->x = x; this->y = y; this->variable_names = variable_names; this->num_rows = num_rows; diff --git a/src/DataSparse.h b/src/DataSparse.h index 3cd904339..de302b5cf 100644 --- a/src/DataSparse.h +++ b/src/DataSparse.h @@ -28,7 +28,7 @@ #ifndef DATASPARSE_H_ #define DATASPARSE_H_ -#include +#include #include "globals.h" #include "utility.h" @@ -40,7 +40,7 @@ class DataSparse: public Data { public: DataSparse() = default; - DataSparse(Eigen::SparseMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, + DataSparse(arma::sp_mat& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols); DataSparse(const DataSparse&) = delete; @@ -54,7 +54,7 @@ class DataSparse: public Data { col = getUnpermutedVarID(col); row = getPermutedSampleID(row); } - return x.coeff(row, col); + return x(row, col); } double get_y(size_t row, size_t col) const override { @@ -67,7 +67,7 @@ class DataSparse: public Data { } void set_x(size_t col, size_t row, double value, bool& error) override { - x.coeffRef(row, col) = value; + x(row, col) = value; } void set_y(size_t col, size_t row, double value, bool& error) override { @@ -76,7 +76,7 @@ class DataSparse: public Data { // #nocov end private: - Eigen::SparseMatrix x; + arma::sp_mat x; Rcpp::NumericMatrix y; }; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 65b57caba..484c60199 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -2,7 +2,7 @@ // Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 #include "../inst/include/ranger.h" -#include +#include #include using namespace Rcpp; @@ -13,7 +13,7 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth); +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, arma::sp_mat& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth); RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; @@ -55,7 +55,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< bool >::type holdout(holdoutSEXP); Rcpp::traits::input_parameter< uint >::type prediction_type_r(prediction_type_rSEXP); Rcpp::traits::input_parameter< uint >::type num_random_splits(num_random_splitsSEXP); - Rcpp::traits::input_parameter< Eigen::SparseMatrix& >::type sparse_x(sparse_xSEXP); + Rcpp::traits::input_parameter< arma::sp_mat& >::type sparse_x(sparse_xSEXP); Rcpp::traits::input_parameter< bool >::type use_sparse_data(use_sparse_dataSEXP); Rcpp::traits::input_parameter< bool >::type order_snps(order_snpsSEXP); Rcpp::traits::input_parameter< bool >::type oob_error(oob_errorSEXP); diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index e743ca151..9bf6d0820 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -25,7 +25,7 @@ http://www.imbs-luebeck.de #-------------------------------------------------------------------------------*/ -#include +#include #include #include #include @@ -46,7 +46,7 @@ using namespace ranger; -// [[Rcpp::depends(RcppEigen)]] +// [[Rcpp::depends(RcppArmadillo)]] // [[Rcpp::export]] Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, @@ -58,7 +58,7 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, - uint num_random_splits, Eigen::SparseMatrix& sparse_x, + uint num_random_splits, arma::sp_mat& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth) { @@ -99,8 +99,8 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM size_t num_rows; size_t num_cols; if (use_sparse_data) { - num_rows = sparse_x.rows(); - num_cols = sparse_x.cols(); + num_rows = sparse_x.n_rows; + num_cols = sparse_x.n_cols; } else { num_rows = input_x.nrow(); num_cols = input_x.ncol(); From 81e3ffb8dca5ebd2010a9929e1a46747213752ee Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 25 Sep 2023 12:04:05 +0200 Subject: [PATCH 2/5] draft version of lm splitting --- DESCRIPTION | 2 +- NEWS.md | 3 +++ R/RcppExports.R | 4 ++-- R/predict.R | 5 ++++- R/ranger.R | 18 +++++++++++++++++- cpp_version/src/version.h | 2 +- man/ranger.Rd | 3 +++ src/Data.h | 10 +++++++++- src/DataRcpp.h | 32 +++++++++++++++++++++++++++++--- src/Makevars | 3 ++- src/Makevars.win | 4 +++- src/RcppExports.cpp | 10 ++++++---- src/Tree.cpp | 2 +- src/Tree.h | 4 ++-- src/TreeRegression.cpp | 27 +++++++++++++++------------ src/rangerCpp.cpp | 10 ++++++++-- 16 files changed, 106 insertions(+), 33 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7f7d3a458..779ecb9e9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests -Version: 0.15.3 +Version: 0.15.4 Date: 2023-07-19 Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] Maintainer: Marvin N. Wright diff --git a/NEWS.md b/NEWS.md index 22a8a8d64..5088d2371 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,7 @@ +# ranger 0.15.4 +* Add linear model residual splitting + # ranger 0.15.3 * Fix min bucket option in C++ version diff --git a/R/RcppExports.R b/R/RcppExports.R index 19cc8e8ac..cd4600040 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) { - .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth) +rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, confounders, use_confounders) { + .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, confounders, use_confounders) } numSmaller <- function(values, reference) { diff --git a/R/predict.R b/R/predict.R index ef1397e73..47feb1228 100644 --- a/R/predict.R +++ b/R/predict.R @@ -250,6 +250,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, regularization.factor <- c(0, 0) use.regularization.factor <- FALSE regularization.usedepth <- FALSE + confounders <- matrix(c(0, 0)) + use.confounders <- FALSE ## Use sparse matrix if (inherits(x, "dgCMatrix")) { @@ -273,7 +275,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, - regularization.factor, use.regularization.factor, regularization.usedepth) + regularization.factor, use.regularization.factor, regularization.usedepth, + confounders, use.confounders) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/R/ranger.R b/R/ranger.R index 54a18342f..cf1d346f4 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -106,6 +106,7 @@ ##' @param minprop For "maxstat" splitrule: Lower quantile of covariate distribution to be considered for splitting. ##' @param split.select.weights Numeric vector with weights between 0 and 1, used to calculate the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used. ##' @param always.split.variables Character vector with variable names to be always selected in addition to the \code{mtry} variables tried for splitting. +##' @param confounders Confounders data.frame to adjust for in regression RF. ##' @param respect.unordered.factors Handling of unordered factor covariates. One of 'ignore', 'order' and 'partition'. For the "extratrees" splitrule the default is "partition" for all other splitrules 'ignore'. Alternatively TRUE (='order') or FALSE (='ignore') can be used. See below for details. ##' @param scale.permutation.importance Scale permutation importance by standard error as in (Breiman 2001). Only applicable if permutation variable importance mode selected. ##' @param regularization.factor Regularization factor (gain penalization), either a vector of length p or one value for all variables. @@ -217,6 +218,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, case.weights = NULL, class.weights = NULL, splitrule = NULL, num.random.splits = 1, alpha = 0.5, minprop = 0.1, split.select.weights = NULL, always.split.variables = NULL, + confounders = NULL, respect.unordered.factors = NULL, scale.permutation.importance = FALSE, local.importance = FALSE, @@ -822,6 +824,19 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } } + ## Confounders + if (is.null(confounders)) { + confounders <- matrix(c(0, 0)) + use.confounders <- FALSE + } else if (is.data.frame(confounders)) { + confounders <- data.matrix(confounders) + use.confounders <- TRUE + } else if (is.matrix(confounders)) { + use.confounders <- TRUE + } else { + stop("Error: confounders argument has to be matrix or data.frame.") + } + ## Prediction mode always false. Use predict.ranger() method. prediction.mode <- FALSE predict.all <- FALSE @@ -873,7 +888,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, predict.all, keep.inbag, sample.fraction, alpha, minprop, holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, - regularization.factor, use.regularization.factor, regularization.usedepth) + regularization.factor, use.regularization.factor, regularization.usedepth, + confounders, use.confounders) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/cpp_version/src/version.h b/cpp_version/src/version.h index 781d7076a..2f82b8ee6 100644 --- a/cpp_version/src/version.h +++ b/cpp_version/src/version.h @@ -1,3 +1,3 @@ #ifndef RANGER_VERSION -#define RANGER_VERSION "0.15.3" +#define RANGER_VERSION "0.15.4" #endif diff --git a/man/ranger.Rd b/man/ranger.Rd index 63d6d395e..0547c1da2 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -25,6 +25,7 @@ ranger( minprop = 0.1, split.select.weights = NULL, always.split.variables = NULL, + confounders = NULL, respect.unordered.factors = NULL, scale.permutation.importance = FALSE, local.importance = FALSE, @@ -88,6 +89,8 @@ ranger( \item{always.split.variables}{Character vector with variable names to be always selected in addition to the \code{mtry} variables tried for splitting.} +\item{confounders}{Confounders data.frame to adjust for in regression RF.} + \item{respect.unordered.factors}{Handling of unordered factor covariates. One of 'ignore', 'order' and 'partition'. For the "extratrees" splitrule the default is "partition" for all other splitrules 'ignore'. Alternatively TRUE (='order') or FALSE (='ignore') can be used. See below for details.} \item{scale.permutation.importance}{Scale permutation importance by standard error as in (Breiman 2001). Only applicable if permutation variable importance mode selected.} diff --git a/src/Data.h b/src/Data.h index c58e5ec66..1d7b5b70b 100644 --- a/src/Data.h +++ b/src/Data.h @@ -33,7 +33,7 @@ class Data { virtual double get_x(size_t row, size_t col) const = 0; virtual double get_y(size_t row, size_t col) const = 0; - + size_t getVariableID(const std::string& variable_name) const; virtual void reserveMemory(size_t y_cols) = 0; @@ -196,6 +196,14 @@ class Data { order_snps = true; } // #nocov end + + virtual void lm(std::vector& sampleIDs, size_t start, size_t end) { + // Empty on purpose + } + + virtual double get_yy(size_t row, size_t col) const { + return get_y(row, col); + } protected: std::vector variable_names; diff --git a/src/DataRcpp.h b/src/DataRcpp.h index ca21561cc..9cd5acd17 100644 --- a/src/DataRcpp.h +++ b/src/DataRcpp.h @@ -27,9 +27,10 @@ Ratzeburger Allee 160 #ifndef DATARCPP_H_ #define DATARCPP_H_ - + #include - +#include + #include "globals.h" #include "utility.h" #include "Data.h" @@ -39,13 +40,16 @@ namespace ranger { class DataRcpp: public Data { public: DataRcpp() = default; - DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols) { + DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols, + Rcpp::NumericMatrix& confounders) { this->x = x; this->y = y; this->variable_names = variable_names; this->num_rows = num_rows; this->num_cols = num_cols; this->num_cols_no_snp = num_cols; + this->confounders = confounders; + this->resid = arma::colvec(y(Rcpp::_, 0)); } DataRcpp(const DataRcpp&) = delete; @@ -86,9 +90,31 @@ class DataRcpp: public Data { } // #nocov end + void lm(std::vector& sampleIDs, size_t start, size_t end) override { + if (confounders.size() > 0) { + std::vector idx; + idx.assign(sampleIDs.begin() + start, sampleIDs.begin() + end); + + arma::uvec ia = arma::conv_to::from(idx); + + arma::mat ca = arma::mat(confounders.begin(), confounders.nrow(), + confounders.ncol(), false); + arma::colvec ya = arma::colvec(y(Rcpp::_, 0)); + + arma::colvec coef = arma::solve(ca.rows(ia), ya(ia), arma::solve_opts::allow_ugly); + resid(ia) = ya(ia) - ca.rows(ia)*coef; + } + } + + double get_yy(size_t row, size_t col) const override { + return resid(row); + } + private: Rcpp::NumericMatrix x; Rcpp::NumericMatrix y; + arma::colvec resid; + Rcpp::NumericMatrix confounders; }; } // namespace ranger diff --git a/src/Makevars b/src/Makevars index a77f23960..41c83d0d3 100644 --- a/src/Makevars +++ b/src/Makevars @@ -1,2 +1,3 @@ -PKG_CPPFLAGS = -DR_BUILD +PKG_CPPFLAGS = -DR_BUILD $(SHLIB_OPENMP_CXXFLAGS) +PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) diff --git a/src/Makevars.win b/src/Makevars.win index a6af4dd1c..4ba80722c 100644 --- a/src/Makevars.win +++ b/src/Makevars.win @@ -1,2 +1,4 @@ -PKG_CPPFLAGS = -DR_BUILD -DWIN_R_BUILD +PKG_CPPFLAGS = -DR_BUILD -DWIN_R_BUILD $(SHLIB_OPENMP_CXXFLAGS) +PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) + diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 484c60199..663ac283f 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,8 +13,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, arma::sp_mat& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth); -RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP) { +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, uint min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, arma::sp_mat& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, Rcpp::NumericMatrix confounders, bool use_confounders); +RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP confoundersSEXP, SEXP use_confoundersSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -65,7 +65,9 @@ BEGIN_RCPP Rcpp::traits::input_parameter< std::vector& >::type regularization_factor(regularization_factorSEXP); Rcpp::traits::input_parameter< bool >::type use_regularization_factor(use_regularization_factorSEXP); Rcpp::traits::input_parameter< bool >::type regularization_usedepth(regularization_usedepthSEXP); - rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth)); + Rcpp::traits::input_parameter< Rcpp::NumericMatrix >::type confounders(confoundersSEXP); + Rcpp::traits::input_parameter< bool >::type use_confounders(use_confoundersSEXP); + rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, confounders, use_confounders)); return rcpp_result_gen; END_RCPP } @@ -96,7 +98,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 47}, + {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 49}, {"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2}, {"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3}, {NULL, NULL, 0} diff --git a/src/Tree.cpp b/src/Tree.cpp index c6ef6303f..6d2c51c29 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -36,7 +36,7 @@ Tree::Tree(std::vector>& child_nodeIDs, std::vector& 0) { } -void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, +void Tree::init(Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, std::vector* sample_fraction, double alpha, diff --git a/src/Tree.h b/src/Tree.h index 3acbfa20f..38636ef6a 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -35,7 +35,7 @@ class Tree { Tree(const Tree&) = delete; Tree& operator=(const Tree&) = delete; - void init(const Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, + void init(Data* data, uint mtry, size_t num_samples, uint seed, std::vector* deterministic_varIDs, std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_bucket, bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, std::vector* manual_inbag, bool keep_inbag, @@ -203,7 +203,7 @@ class Tree { std::mt19937_64 random_number_generator; // Pointer to original data - const Data* data; + Data* data; // Regularization bool regularization; diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index 640395a6f..cbe0ae945 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -83,6 +83,9 @@ bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector& possi split_values[nodeID] = pure_value; return true; } + + // Fit linear model and save residuals + data->lm(sampleIDs, start_pos[nodeID], end_pos[nodeID]); // Find best split, stop if no decrease of impurity bool stop; @@ -138,7 +141,7 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_ double sum_node = 0; for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - sum_node += data->get_y(sampleID, 0); + sum_node += data->get_yy(sampleID, 0); } // Stop early if no split posssible @@ -223,7 +226,7 @@ void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, doubl size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), data->get_x(sampleID, varID)) - possible_split_values.begin(); - sums[idx] += data->get_y(sampleID, 0); + sums[idx] += data->get_yy(sampleID, 0); ++counter[idx]; } @@ -285,7 +288,7 @@ void TreeRegression::findBestSplitValueLargeQ(size_t nodeID, size_t varID, doubl size_t sampleID = sampleIDs[pos]; size_t index = data->getIndex(sampleID, varID); - sums[index] += data->get_y(sampleID, 0); + sums[index] += data->get_yy(sampleID, 0); ++counter[index]; } @@ -378,7 +381,7 @@ void TreeRegression::findBestSplitValueUnordered(size_t nodeID, size_t varID, do // Sum in right child for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); double value = data->get_x(sampleID, varID); size_t factorID = floor(value) - 1; @@ -421,7 +424,7 @@ bool TreeRegression::findBestSplitMaxstat(size_t nodeID, std::vector& po response.reserve(num_samples_node); for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - response.push_back(data->get_y(sampleID, 0)); + response.push_back(data->get_yy(sampleID, 0)); } std::vector ranks = rank(response); @@ -526,7 +529,7 @@ bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector& double sum_node = 0; for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - sum_node += data->get_y(sampleID, 0); + sum_node += data->get_yy(sampleID, 0); } // Stop early if no split posssible @@ -612,7 +615,7 @@ void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, d for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; double value = data->get_x(sampleID, varID); - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); // Count samples until split_value reached for (size_t i = 0; i < num_splits; ++i) { @@ -721,7 +724,7 @@ void TreeRegression::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t // Sum in right child for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); double value = data->get_x(sampleID, varID); size_t factorID = floor(value) - 1; @@ -766,7 +769,7 @@ bool TreeRegression::findBestSplitBeta(size_t nodeID, std::vector& possi double sum_node = 0; for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; - sum_node += data->get_y(sampleID, 0); + sum_node += data->get_yy(sampleID, 0); } // Stop early if no split posssible @@ -835,7 +838,7 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; double value = data->get_x(sampleID, varID); - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); // Count samples until split_value reached for (size_t i = 0; i < num_splits; ++i) { @@ -874,7 +877,7 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; double value = data->get_x(sampleID, varID); - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); if (value > possible_split_values[i]) { var_right += (response - mean_right) * (response - mean_right); @@ -900,7 +903,7 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; double value = data->get_x(sampleID, varID); - double response = data->get_y(sampleID, 0); + double response = data->get_yy(sampleID, 0); if (value > possible_split_values[i]) { beta_loglik_right += betaLogLik(response, mean_right, phi_right); diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index 9bf6d0820..0b5af75d5 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -25,6 +25,8 @@ http://www.imbs-luebeck.de #-------------------------------------------------------------------------------*/ +#define ARMA_WARN_LEVEL 1 + #include #include #include @@ -61,7 +63,8 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM uint num_random_splits, arma::sp_mat& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, - std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth) { + std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, + Rcpp::NumericMatrix confounders, bool use_confounders) { Rcpp::List result; @@ -88,6 +91,9 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM if (!use_regularization_factor) { regularization_factor.clear(); } + if (!use_confounders) { + confounders = Rcpp::NumericMatrix(); + } std::ostream* verbose_out; if (verbose) { @@ -110,7 +116,7 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM if (use_sparse_data) { data = std::make_unique(sparse_x, input_y, variable_names, num_rows, num_cols); } else { - data = std::make_unique(input_x, input_y, variable_names, num_rows, num_cols); + data = std::make_unique(input_x, input_y, variable_names, num_rows, num_cols, confounders); } // If there is snp data, add it From 84d551561fb2086b4877207106c209b936011c9c Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 28 Sep 2023 10:04:57 +0200 Subject: [PATCH 3/5] lm with intercept --- R/ranger.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/ranger.R b/R/ranger.R index cf1d346f4..33da7f353 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -829,9 +829,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, confounders <- matrix(c(0, 0)) use.confounders <- FALSE } else if (is.data.frame(confounders)) { - confounders <- data.matrix(confounders) + confounders <- cbind(1, data.matrix(confounders)) use.confounders <- TRUE } else if (is.matrix(confounders)) { + confounders <- cbind(1, confounders) use.confounders <- TRUE } else { stop("Error: confounders argument has to be matrix or data.frame.") From a01103ee31b1eeec116dc4ab7f86c5f3e942940a Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 20 Nov 2023 21:59:49 +0100 Subject: [PATCH 4/5] add GLM in terminal nodes for prediction and variable importance --- R/predict.R | 28 +++++++++++++++++++++++++++- man/predict.ranger.forest.Rd | 1 + src/Data.h | 10 ++++++++++ src/DataRcpp.h | 23 +++++++++++++++++++++++ src/ForestRegression.cpp | 10 ++++++++++ src/ForestRegression.h | 2 ++ src/TreeRegression.cpp | 14 +++++++++++++- src/TreeRegression.h | 7 +++++++ src/rangerCpp.cpp | 2 ++ 9 files changed, 95 insertions(+), 2 deletions(-) diff --git a/R/predict.R b/R/predict.R index a6c5d974f..3d9d04060 100644 --- a/R/predict.R +++ b/R/predict.R @@ -73,7 +73,8 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, num.trees = object$num.trees, type = "response", se.method = "infjack", seed = NULL, num.threads = NULL, - verbose = TRUE, inbag.counts = NULL, ...) { + verbose = TRUE, inbag.counts = NULL, + confounders = NULL, ...) { ## GenABEL GWA data if (inherits(data, "gwaa.data")) { @@ -122,6 +123,31 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, stop("Error: Invalid value for 'type'. Use 'response', 'se', 'terminalNodes', or 'quantiles'.") } + if (!is.null(confounders)) { + if (is.null(forest$glm.coefs) || length(unlist(forest$glm.coefs)) == 0) { + stop("For glm prediction, fit a regression RF with the confounders argument.") + } + if (is.data.frame(confounders)) { + confounders <- cbind(1, data.matrix(confounders)) + } else if (is.matrix(confounders)) { + confounders <- cbind(1, confounders) + } else { + stop("Error: confounders argument has to be matrix or data.frame.") + } + nodes <- predict(object = object, data = data, predict.all = predict.all, + num.trees = num.trees, type = "terminalNodes", se.method = se.method, + seed = seed, num.threads = num.threads, + verbose = verbose, inbag.counts = inbag.counts, ...)$predictions + + pred <- sapply(1:num.trees, function(i) { + tree_coefs <- forest$glm.coefs[[i]][nodes[, i] + 1] + sapply(1:length(tree_coefs), function(j) { + confounders[j, ] %*% tree_coefs[[j]] + }) + }) + return(pred) + } + ## Type "se" only for certain tree types if (type == "se" && se.method == "jack" && forest$treetype != "Regression") { stop("Error: Jackknife standard error prediction currently only available for regression.") diff --git a/man/predict.ranger.forest.Rd b/man/predict.ranger.forest.Rd index ba018b0e3..0a99def04 100644 --- a/man/predict.ranger.forest.Rd +++ b/man/predict.ranger.forest.Rd @@ -15,6 +15,7 @@ num.threads = NULL, verbose = TRUE, inbag.counts = NULL, + confounders = NULL, ... ) } diff --git a/src/Data.h b/src/Data.h index 1d7b5b70b..7ca83c1ec 100644 --- a/src/Data.h +++ b/src/Data.h @@ -201,6 +201,16 @@ class Data { // Empty on purpose } + virtual std::vector lm_coefs(std::vector& sampleIDs, size_t start, size_t end) { + // Empty on purpose + return std::vector(); + } + + virtual double predict(size_t row, std::vector coefs) { + // Empty on purpose + return 0; + } + virtual double get_yy(size_t row, size_t col) const { return get_y(row, col); } diff --git a/src/DataRcpp.h b/src/DataRcpp.h index 9cd5acd17..0e65b2906 100644 --- a/src/DataRcpp.h +++ b/src/DataRcpp.h @@ -106,6 +106,29 @@ class DataRcpp: public Data { } } + std::vector lm_coefs(std::vector& sampleIDs, size_t start, size_t end) override { + if (confounders.size() > 0) { + std::vector idx; + idx.assign(sampleIDs.begin() + start, sampleIDs.begin() + end); + + arma::uvec ia = arma::conv_to::from(idx); + + arma::mat ca = arma::mat(confounders.begin(), confounders.nrow(), + confounders.ncol(), false); + arma::colvec ya = arma::colvec(y(Rcpp::_, 0)); + + arma::colvec coef = arma::solve(ca.rows(ia), ya(ia), arma::solve_opts::allow_ugly); + + return arma::conv_to>::from(coef); + } else { + return std::vector(); + } + } + + double predict(size_t row, std::vector coefs) override { + return arma::dot(arma::vec(confounders(row, Rcpp::_)), arma::vec(coefs)); + } + double get_yy(size_t row, size_t col) const override { return resid(row); } diff --git a/src/ForestRegression.cpp b/src/ForestRegression.cpp index 7c1bb3269..37bb1df98 100644 --- a/src/ForestRegression.cpp +++ b/src/ForestRegression.cpp @@ -39,6 +39,16 @@ void ForestRegression::loadForest(size_t num_trees, equalSplit(thread_ranges, 0, num_trees - 1, num_threads); } +std::vector>> ForestRegression::getGlmCoefs() const { + std::vector>> result; + result.reserve(num_trees); + for (const auto& tree : trees) { + const auto& temp = dynamic_cast(*tree); + result.push_back(temp.getGlmCoefs()); + } + return result; +} + void ForestRegression::initInternal() { // If mtry not set, use floored square root of number of independent variables diff --git a/src/ForestRegression.h b/src/ForestRegression.h index 62689d38e..313b631fa 100644 --- a/src/ForestRegression.h +++ b/src/ForestRegression.h @@ -32,6 +32,8 @@ class ForestRegression: public Forest { void loadForest(size_t num_trees, std::vector> >& forest_child_nodeIDs, std::vector>& forest_split_varIDs, std::vector>& forest_split_values, std::vector& is_ordered_variable); + + std::vector>> getGlmCoefs() const; private: void initInternal() override; diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index 45c5cf1a4..937a04bf8 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -42,6 +42,8 @@ void TreeRegression::allocateMemory() { } double TreeRegression::estimate(size_t nodeID) { + + glm_coefs[nodeID] = data->lm_coefs(sampleIDs, start_pos[nodeID], end_pos[nodeID]); // Mean of responses of samples in node double sum_responses_in_node = 0; @@ -87,6 +89,7 @@ bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector& possi } if (pure) { split_values[nodeID] = pure_value; + glm_coefs[nodeID] = data->lm_coefs(sampleIDs, start_pos[nodeID], end_pos[nodeID]); return true; } @@ -117,6 +120,7 @@ void TreeRegression::createEmptyNodeInternal() { if (save_node_stats) { node_predictions.push_back(0); } + glm_coefs.push_back(std::vector()); } double TreeRegression::computePredictionAccuracyInternal(std::vector* prediction_error_casewise) { @@ -125,7 +129,15 @@ double TreeRegression::computePredictionAccuracyInternal(std::vector* pr double sum_of_squares = 0; for (size_t i = 0; i < num_predictions; ++i) { size_t terminal_nodeID = prediction_terminal_nodeIDs[i]; - double predicted_value = split_values[terminal_nodeID]; + + double predicted_value; + if (glm_coefs[terminal_nodeID].size() > 0) { + // Get predicted value from glm in terminal node + predicted_value = data->predict(oob_sampleIDs[i], glm_coefs[terminal_nodeID]); + } else { + predicted_value = split_values[terminal_nodeID]; + } + double real_value = data->get_y(oob_sampleIDs[i], 0); if (predicted_value != real_value) { double diff = (predicted_value - real_value) * (predicted_value - real_value); diff --git a/src/TreeRegression.h b/src/TreeRegression.h index 84c224f63..7a05e3786 100644 --- a/src/TreeRegression.h +++ b/src/TreeRegression.h @@ -46,6 +46,10 @@ class TreeRegression: public Tree { size_t getPredictionTerminalNodeID(size_t sampleID) const { return prediction_terminal_nodeIDs[sampleID]; } + + const std::vector>& getGlmCoefs() const { + return glm_coefs; + } private: bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) override; @@ -93,6 +97,9 @@ class TreeRegression: public Tree { sums.clear(); sums.shrink_to_fit(); } + + // GLM coefficients in terminal nodes. Empty for non-terminal nodes (except if save_node_stats). + std::vector> glm_coefs; std::vector counter; std::vector sums; diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index 515f6ff25..d2215a4e7 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -285,6 +285,8 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM if (node_stats) { forest_object.push_back(forest->getNodePredictions(), "node.predictions"); } + auto& temp = dynamic_cast(*forest); + forest_object.push_back(temp.getGlmCoefs(), "glm.coefs"); } else if (treetype == TREE_PROBABILITY) { auto& temp = dynamic_cast(*forest); forest_object.push_back(temp.getClassValues(), "class.values"); From a91d4cad02c75ca2af486e67d9229b78a822c6ae Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Tue, 21 Nov 2023 13:59:12 +0100 Subject: [PATCH 5/5] handle categorical features --- R/predict.R | 2 +- R/ranger.R | 2 +- src/DataRcpp.h | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/predict.R b/R/predict.R index 3d9d04060..bd0d2c233 100644 --- a/R/predict.R +++ b/R/predict.R @@ -128,7 +128,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, stop("For glm prediction, fit a regression RF with the confounders argument.") } if (is.data.frame(confounders)) { - confounders <- cbind(1, data.matrix(confounders)) + confounders <- model.matrix( ~ ., confounders) } else if (is.matrix(confounders)) { confounders <- cbind(1, confounders) } else { diff --git a/R/ranger.R b/R/ranger.R index c48962fd1..2f6a41c22 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -857,7 +857,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, confounders <- matrix(c(0, 0)) use.confounders <- FALSE } else if (is.data.frame(confounders)) { - confounders <- cbind(1, data.matrix(confounders)) + confounders <- model.matrix( ~ ., confounders) use.confounders <- TRUE } else if (is.matrix(confounders)) { confounders <- cbind(1, confounders) diff --git a/src/DataRcpp.h b/src/DataRcpp.h index 0e65b2906..cf79b29d0 100644 --- a/src/DataRcpp.h +++ b/src/DataRcpp.h @@ -101,7 +101,7 @@ class DataRcpp: public Data { confounders.ncol(), false); arma::colvec ya = arma::colvec(y(Rcpp::_, 0)); - arma::colvec coef = arma::solve(ca.rows(ia), ya(ia), arma::solve_opts::allow_ugly); + arma::colvec coef = arma::solve(ca.rows(ia), ya(ia)); resid(ia) = ya(ia) - ca.rows(ia)*coef; } } @@ -117,7 +117,7 @@ class DataRcpp: public Data { confounders.ncol(), false); arma::colvec ya = arma::colvec(y(Rcpp::_, 0)); - arma::colvec coef = arma::solve(ca.rows(ia), ya(ia), arma::solve_opts::allow_ugly); + arma::colvec coef = arma::solve(ca.rows(ia), ya(ia)); return arma::conv_to>::from(coef); } else {