diff --git a/cpp_version/src/main.cpp b/cpp_version/src/main.cpp index b4b9d0f0..79b2b1c3 100644 --- a/cpp_version/src/main.cpp +++ b/cpp_version/src/main.cpp @@ -55,7 +55,7 @@ void run_ranger(const ArgumentHandler& arg_handler, std::ostream& verbose_out) { arg_handler.predict, arg_handler.impmeasure, arg_handler.targetpartitionsize, arg_handler.minbucket, arg_handler.splitweights, arg_handler.alwayssplitvars, arg_handler.statusvarname, arg_handler.replace, arg_handler.catvars, arg_handler.savemem, arg_handler.splitrule, arg_handler.caseweights, arg_handler.predall, arg_handler.fraction, - arg_handler.alpha, arg_handler.minprop, arg_handler.holdout, arg_handler.predictiontype, + arg_handler.alpha, arg_handler.minprop, arg_handler.tau, arg_handler.holdout, arg_handler.predictiontype, arg_handler.randomsplits, arg_handler.maxdepth, arg_handler.regcoef, arg_handler.usedepth); forest->run(true, !arg_handler.skipoob); diff --git a/cpp_version/src/utility/ArgumentHandler.cpp b/cpp_version/src/utility/ArgumentHandler.cpp index 1da4743f..d5b0a6a5 100644 --- a/cpp_version/src/utility/ArgumentHandler.cpp +++ b/cpp_version/src/utility/ArgumentHandler.cpp @@ -21,7 +21,7 @@ namespace ranger { ArgumentHandler::ArgumentHandler(int argc, char **argv) : caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict( - ""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads( + ""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), tau(DEFAULT_POISSON_TAU), nthreads( DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth( DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), minbucket(0), mtry(0), outprefix( "ranger_out"), probability(false), splitrule(DEFAULT_SPLITRULE), statusvarname(""), ntree(DEFAULT_NUM_TREE), replace( @@ -33,7 +33,7 @@ ArgumentHandler::ArgumentHandler(int argc, char **argv) : int ArgumentHandler::processArguments() { // short options - char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:"; + char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:T:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:"; // long options: longname, no/optional/required argument?, flag(not used!), shortname const struct option long_options[] = { @@ -50,6 +50,7 @@ int ArgumentHandler::processArguments() { { "predictiontype", required_argument, 0, 'Q'}, { "randomsplits", required_argument, 0, 'R'}, { "splitweights", required_argument, 0, 'S'}, + { "tau", required_argument, 0, 'T'}, { "nthreads", required_argument, 0, 'U'}, { "predall", no_argument, 0, 'X'}, { "version", no_argument, 0, 'Z'}, @@ -178,6 +179,20 @@ int ArgumentHandler::processArguments() { case 'S': splitweights = optarg; break; + + case 'T': + try { + double temp = std::stod(optarg); + if (temp <= 0) { + throw std::runtime_error(""); + } else { + tau = temp; + } + } catch (...) { + throw std::runtime_error( + "Illegal argument for option 'tau'. Please give a positive value. See '--help' for details."); + } + break; case 'U': try { @@ -352,6 +367,9 @@ int ArgumentHandler::processArguments() { case 7: splitrule = HELLINGER; break; + case 8: + splitrule = POISSON; + break; default: throw std::runtime_error(""); break; @@ -512,7 +530,8 @@ void ArgumentHandler::checkArguments() { if (((splitrule == AUC || splitrule == AUC_IGNORE_TIES) && treetype != TREE_SURVIVAL) || (splitrule == MAXSTAT && (treetype != TREE_SURVIVAL && treetype != TREE_REGRESSION)) || (splitrule == BETA && treetype != TREE_REGRESSION) - || (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY)) { + || (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY) + || (splitrule == POISSON && treetype != TREE_REGRESSION)) { throw std::runtime_error("Illegal splitrule selected. See '--help' for details."); } @@ -658,8 +677,9 @@ void ArgumentHandler::displayHelp() { << " RULE = 4: MAXSTAT for Survival and Regression, not available for Classification." << std::endl; std::cout << " " << " RULE = 5: ExtraTrees for all tree types." << std::endl; - std::cout << " " << " RULE = 6: BETA for regression, only for (0,1) bounded outcomes." << std::endl; + std::cout << " " << " RULE = 6: BETA for Regression, only for (0,1) bounded outcomes." << std::endl; std::cout << " " << " RULE = 7: Hellinger for Classification, not available for Regression and Survival." << std::endl; + std::cout << " " << " RULE = 8: Poisson for Regression, not available for Classification and Survival." << std::endl; std::cout << " " << " (Default: 1)" << std::endl; std::cout << " " << "--randomsplits N Number of random splits to consider for each splitting variable (ExtraTrees splitrule only)." @@ -670,6 +690,9 @@ void ArgumentHandler::displayHelp() { std::cout << " " << "--minprop VAL Lower quantile of covariate distribtuion to be considered for splitting (MAXSTAT splitrule only)." << std::endl; + std::cout << " " + << "--tau VAL Tau parameter for Poisson splitting (Poisson splitrule only)." + << std::endl; std::cout << " " << "--caseweights FILE Filename of case weights file." << std::endl; std::cout << " " << "--holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable " diff --git a/cpp_version/src/utility/ArgumentHandler.h b/cpp_version/src/utility/ArgumentHandler.h index d6964093..4395c95d 100644 --- a/cpp_version/src/utility/ArgumentHandler.h +++ b/cpp_version/src/utility/ArgumentHandler.h @@ -60,6 +60,7 @@ class ArgumentHandler { PredictionType predictiontype; uint randomsplits; std::string splitweights; + double tau; uint nthreads; bool predall;