Skip to content

Commit

Permalink
add Poisson splitrule to pure C++ version
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed May 16, 2024
1 parent ed2b73d commit 24bd170
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cpp_version/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void run_ranger(const ArgumentHandler& arg_handler, std::ostream& verbose_out) {
arg_handler.predict, arg_handler.impmeasure, arg_handler.targetpartitionsize, arg_handler.minbucket, arg_handler.splitweights,
arg_handler.alwayssplitvars, arg_handler.statusvarname, arg_handler.replace, arg_handler.catvars,
arg_handler.savemem, arg_handler.splitrule, arg_handler.caseweights, arg_handler.predall, arg_handler.fraction,
arg_handler.alpha, arg_handler.minprop, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.alpha, arg_handler.minprop, arg_handler.tau, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.randomsplits, arg_handler.maxdepth, arg_handler.regcoef, arg_handler.usedepth);

forest->run(true, !arg_handler.skipoob);
Expand Down
31 changes: 27 additions & 4 deletions cpp_version/src/utility/ArgumentHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace ranger {

ArgumentHandler::ArgumentHandler(int argc, char **argv) :
caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), tau(DEFAULT_POISSON_TAU), nthreads(
DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth(
DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), minbucket(0), mtry(0), outprefix(
"ranger_out"), probability(false), splitrule(DEFAULT_SPLITRULE), statusvarname(""), ntree(DEFAULT_NUM_TREE), replace(
Expand All @@ -33,7 +33,7 @@ ArgumentHandler::ArgumentHandler(int argc, char **argv) :
int ArgumentHandler::processArguments() {

// short options
char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:";
char const *short_options = "A:C:D:F:HM:NOP:Q:R:S:T:U:XZa:b:c:d:f:hi:j:kl:m:n:o:pr:s:t:uvwy:z:";

// long options: longname, no/optional/required argument?, flag(not used!), shortname
const struct option long_options[] = {
Expand All @@ -50,6 +50,7 @@ int ArgumentHandler::processArguments() {
{ "predictiontype", required_argument, 0, 'Q'},
{ "randomsplits", required_argument, 0, 'R'},
{ "splitweights", required_argument, 0, 'S'},
{ "tau", required_argument, 0, 'T'},
{ "nthreads", required_argument, 0, 'U'},
{ "predall", no_argument, 0, 'X'},
{ "version", no_argument, 0, 'Z'},
Expand Down Expand Up @@ -178,6 +179,20 @@ int ArgumentHandler::processArguments() {
case 'S':
splitweights = optarg;
break;

case 'T':
  // --tau: tau parameter for the Poisson splitrule; must be a strictly positive number.
  try {
    std::size_t parsed_chars = 0;
    const double temp = std::stod(optarg, &parsed_chars);
    // std::stod stops at the first character it cannot parse, so "1.5abc" would
    // silently become 1.5; require that the whole argument was consumed.
    // !(temp > 0) is used instead of (temp <= 0) so that NaN is rejected as well.
    if (optarg[parsed_chars] != '\0' || !(temp > 0)) {
      throw std::runtime_error("");
    }
    tau = temp;
  } catch (...) {
    // Any parse failure or invalid value is reported with the same user-facing message.
    throw std::runtime_error(
        "Illegal argument for option 'tau'. Please give a positive value. See '--help' for details.");
  }
  break;

case 'U':
try {
Expand Down Expand Up @@ -352,6 +367,9 @@ int ArgumentHandler::processArguments() {
case 7:
splitrule = HELLINGER;
break;
case 8:
splitrule = POISSON;
break;
default:
throw std::runtime_error("");
break;
Expand Down Expand Up @@ -512,7 +530,8 @@ void ArgumentHandler::checkArguments() {
if (((splitrule == AUC || splitrule == AUC_IGNORE_TIES) && treetype != TREE_SURVIVAL)
|| (splitrule == MAXSTAT && (treetype != TREE_SURVIVAL && treetype != TREE_REGRESSION))
|| (splitrule == BETA && treetype != TREE_REGRESSION)
|| (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY)) {
|| (splitrule == HELLINGER && treetype != TREE_CLASSIFICATION && treetype != TREE_PROBABILITY)
|| (splitrule == POISSON && treetype != TREE_REGRESSION)) {
throw std::runtime_error("Illegal splitrule selected. See '--help' for details.");
}

Expand Down Expand Up @@ -658,8 +677,9 @@ void ArgumentHandler::displayHelp() {
<< " RULE = 4: MAXSTAT for Survival and Regression, not available for Classification."
<< std::endl;
std::cout << " " << " RULE = 5: ExtraTrees for all tree types." << std::endl;
std::cout << " " << " RULE = 6: BETA for regression, only for (0,1) bounded outcomes." << std::endl;
std::cout << " " << " RULE = 6: BETA for Regression, only for (0,1) bounded outcomes." << std::endl;
std::cout << " " << " RULE = 7: Hellinger for Classification, not available for Regression and Survival." << std::endl;
std::cout << " " << " RULE = 8: Poisson for Regression, not available for Classification and Survival." << std::endl;
std::cout << " " << " (Default: 1)" << std::endl;
std::cout << " "
<< "--randomsplits N Number of random splits to consider for each splitting variable (ExtraTrees splitrule only)."
Expand All @@ -670,6 +690,9 @@ void ArgumentHandler::displayHelp() {
// Help entry for --minprop (MAXSTAT splitrule only).
std::cout << " "
<< "--minprop VAL Lower quantile of covariate distribution to be considered for splitting (MAXSTAT splitrule only)."
<< std::endl;
std::cout << " "
<< "--tau VAL Tau parameter for Poisson splitting (Poisson splitrule only)."
<< std::endl;
std::cout << " " << "--caseweights FILE Filename of case weights file." << std::endl;
std::cout << " "
<< "--holdout Hold-out mode. Hold-out all samples with case weight 0 and use these for variable "
Expand Down
1 change: 1 addition & 0 deletions cpp_version/src/utility/ArgumentHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ArgumentHandler {
PredictionType predictiontype;
uint randomsplits;
std::string splitweights;
double tau;
uint nthreads;
bool predall;

Expand Down

0 comments on commit 24bd170

Please sign in to comment.