Skip to content

Commit

Permalink
Merge pull request #698 from talegari/add_depvar
Browse files Browse the repository at this point in the history
added depvar to result
  • Loading branch information
mnwright authored Nov 8, 2023
2 parents 4ce5894 + 0bf7876 commit 9b7e3a0
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 7 deletions.
11 changes: 11 additions & 0 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
##' \item{\code{importance.mode}}{Importance mode used.}
##' \item{\code{num.samples}}{Number of samples.}
##' \item{\code{inbag.counts}}{Number of times the observations are in-bag in the trees.}
##' \item{\code{dependent.variable.name}}{Name of the dependent variable. This is NULL when x/y interface is used.}
##' \item{\code{status.variable.name}}{Name of the status variable (survival only). This is NULL when x/y interface is used.}
##' @examples
##' ## Classification forest with default settings
##' ranger(Species ~ ., data = iris)
Expand Down Expand Up @@ -277,6 +279,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Invalid formula.")
}
data.selected <- parse.formula(formula, data, env = parent.frame())
dependent.variable.name <- all.vars(formula)[1]
if (survival::is.Surv(data.selected[, 1])) {
status.variable.name <- all.vars(formula)[2]
}
y <- data.selected[, 1]
x <- data.selected[, -1, drop = FALSE]
}
Expand Down Expand Up @@ -1002,6 +1008,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}
}

## Dependent (and status) variable name
## will be NULL only when x/y interface is used
result$dependent.variable.name <- dependent.variable.name
result$status.variable.name <- status.variable.name

class(result) <- "ranger"

## Prepare quantile prediction
Expand Down
2 changes: 2 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ rg.class <- ranger(Species ~ ., data = iris)
rg.mat <- ranger(dependent.variable.name = "Species", data = dat, classification = TRUE)

## Basic tests (for all random forests equal)
test_that("classification result is of class ranger with 14 elements", {
test_that("classification result is of class ranger with 15 elements", {
expect_is(rg.class, "ranger")
expect_equal(length(rg.class), 14)
expect_equal(length(rg.class), 15)
})

test_that("classification prediction returns factor", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_print.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ expect_that(print(rf$forest), prints_text("Ranger forest object"))
expect_that(print(predict(rf, iris)), prints_text("Ranger prediction"))

## Test str ranger function
expect_that(str(rf), prints_text("List of 14"))
expect_that(str(rf), prints_text("List of 15"))

## Test str forest function
expect_that(str(rf$forest), prints_text("List of 9"))
4 changes: 2 additions & 2 deletions tests/testthat/test_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ context("ranger_reg")
rg.reg <- ranger(Sepal.Length ~ ., data = iris)

## Basic tests (for all random forests equal)
test_that("regression result is of class ranger with 14 elements", {
test_that("regression result is of class ranger with 15 elements", {
expect_is(rg.reg, "ranger")
expect_equal(length(rg.reg), 14)
expect_equal(length(rg.reg), 15)
})

test_that("regression prediction returns numeric vector", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ context("ranger_surv")
rg.surv <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 10)

## Basic tests (for all random forests equal)
test_that("survival result is of class ranger with 15 elements", {
test_that("survival result is of class ranger with 17 elements", {
expect_is(rg.surv, "ranger")
expect_equal(length(rg.surv), 15)
expect_equal(length(rg.surv), 17)
})

test_that("results have right number of trees", {
Expand Down

0 comments on commit 9b7e3a0

Please sign in to comment.