Skip to content

Commit

Permalink
update tests based on recent parsnip testing changes (#35)
Browse files Browse the repository at this point in the history
* update tests based on recent parsnip testing changes

* add generated snapshot

* add glm execution tests back in

Co-authored-by: simonpcouch <[email protected]>
  • Loading branch information
topepo and simonpcouch authored Jun 14, 2022
1 parent 15b0904 commit d42c32c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 162 deletions.
3 changes: 3 additions & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ glm
stan
warmup
ORCID
funder
poisson
zeroinfl
1 change: 1 addition & 0 deletions poissonreg.Rproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ LineEndingConversion: Posix

BuildType: Package
PackageUseDevtools: Yes
PackageCleanBeforeInstall: Yes
PackageInstallArgs: --no-multiarch --with-keep.source
PackageRoxygenize: rd,collate,namespace
18 changes: 18 additions & 0 deletions tests/testthat/_snaps/poisson-reg.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# updating

Code
poisson_reg(penalty = 1) %>% set_engine("glmnet", lambda.min.ratio = 0.001) %>%
update(mixture = tune())
Output
Poisson Regression Model Specification (regression)
Main Arguments:
penalty = 1
mixture = tune()
Engine-Specific Arguments:
lambda.min.ratio = 0.001
Computational engine: glmnet

171 changes: 9 additions & 162 deletions tests/testthat/test-poisson-reg.R
Original file line number Diff line number Diff line change
@@ -1,163 +1,19 @@
test_that("primary arguments", {
basic <- poisson_reg()
basic_lm <- translate(basic %>% set_engine("glm"))
basic_glmnet <- translate(basic %>% set_engine("glmnet"))
basic_stan <- translate(basic %>% set_engine("stan"))
expect_equal(
basic_lm$method$fit$args,
list(
formula = expr(missing_arg()),
data = expr(missing_arg()),
weights = expr(missing_arg()),
family = expr(stats::poisson)
)
)
expect_equal(
basic_glmnet$method$fit$args,
list(
x = expr(missing_arg()),
y = expr(missing_arg()),
weights = expr(missing_arg()),
family = "poisson"
)
)
expect_equal(
basic_stan$method$fit$args,
list(
formula = expr(missing_arg()),
data = expr(missing_arg()),
weights = expr(missing_arg()),
family = expr(stats::poisson)
)
)

mixture <- poisson_reg(mixture = 0.128)
mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
expect_equal(
mixture_glmnet$method$fit$args,
list(
x = expr(missing_arg()),
y = expr(missing_arg()),
weights = expr(missing_arg()),
alpha = new_empty_quosure(0.128),
family = "poisson"
)
)

penalty <- poisson_reg(penalty = 1)
penalty_glmnet <- translate(penalty %>% set_engine("glmnet"))
expect_equal(
penalty_glmnet$method$fit$args,
list(
x = expr(missing_arg()),
y = expr(missing_arg()),
weights = expr(missing_arg()),
family = "poisson"
)
)

mixture_v <- poisson_reg(mixture = tune())
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
expect_equal(
mixture_v_glmnet$method$fit$args,
list(
x = expr(missing_arg()),
y = expr(missing_arg()),
weights = expr(missing_arg()),
alpha = new_quosure(tune()),
family = "poisson"
)
test_that('updating', {
expect_snapshot(
poisson_reg(penalty = 1) %>%
set_engine("glmnet", lambda.min.ratio = 0.001) %>%
update(mixture = tune())
)
})

test_that("engine arguments", {
glm_fam <- poisson_reg() %>% set_engine("glm", model = FALSE)
expect_equal(
translate(glm_fam)$method$fit$args,
list(
formula = expr(missing_arg()),
data = expr(missing_arg()),
weights = expr(missing_arg()),
model = new_empty_quosure(FALSE),
family = expr(stats::poisson)
)
)

glmnet_nlam <- poisson_reg() %>% set_engine("glmnet", nlambda = 10)
expect_equal(
translate(glmnet_nlam)$method$fit$args,
list(
x = expr(missing_arg()),
y = expr(missing_arg()),
weights = expr(missing_arg()),
nlambda = new_empty_quosure(10),
family = "poisson"
)
)

stan_samp <- poisson_reg() %>% set_engine("stan", chains = 1, iter = 5)
expect_equal(
translate(stan_samp)$method$fit$args,
list(
formula = expr(missing_arg()),
data = expr(missing_arg()),
weights = expr(missing_arg()),
chains = new_empty_quosure(1),
iter = new_empty_quosure(5),
family = expr(stats::poisson)
)
)
})


test_that("updating", {
expr1 <- poisson_reg() %>% set_engine("glm", model = FALSE)
expr1_exp <- poisson_reg(mixture = 0) %>% set_engine("glm", model = FALSE)

expr2 <- poisson_reg(mixture = varying()) %>% set_engine("glmnet")
expr2_exp <- poisson_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10)

expr3 <- poisson_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet")
expr3_exp <- poisson_reg(mixture = 1) %>% set_engine("glmnet")

expr4 <- poisson_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10)
expr4_exp <- poisson_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2)

expr5 <- poisson_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10)
expr5_exp <- poisson_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2)

expect_equal(update(expr1, mixture = 0), expr1_exp)
expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp)

param_tibb <- tibble::tibble(mixture = 1 / 3, penalty = 1)
param_list <- as.list(param_tibb)

expr4_updated <- update(expr4, param_tibb)
expect_equal(expr4_updated$args$mixture, 1 / 3)
expect_equal(expr4_updated$args$penalty, 1)
expect_equal(expr4_updated$eng_args$nlambda, rlang::quo(10))

expr4_updated_lst <- update(expr4, param_list)
expect_equal(expr4_updated_lst$args$mixture, 1 / 3)
expect_equal(expr4_updated_lst$args$penalty, 1)
expect_equal(expr4_updated_lst$eng_args$nlambda, rlang::quo(10))
})

test_that("bad input", {
expect_error(poisson_reg(mode = "classification"))
expect_error(translate(poisson_reg(), engine = "wat?"))
expect_error(translate(poisson_reg(), engine = NULL))
test_that('bad input', {
expect_error(poisson_reg(mode = "bogus"))
expect_error(translate(poisson_reg(mode = "regression"), engine = NULL))
expect_error(translate(poisson_reg(formula = y ~ x)))
expect_error(translate(poisson_reg(x = seniors[, 1:3], y = factor(seniors$count)) %>% set_engine("glmnet")))
expect_error(translate(poisson_reg(formula = y ~ x) %>% set_engine("glm")))
})

test_that("printing", {
expect_output(print(poisson_reg()))
})


# ------------------------------------------------------------------------------
# glm execution tests

test_that("glm execution", {
expect_error(
Expand Down Expand Up @@ -211,12 +67,3 @@ test_that("glm prediction", {
expect_equal(glm_pred, predict(res_form, seniors[1:3, ])$.pred)
})

test_that("newdata error trapping", {
res_xy <- fit_xy(
glm_spec,
x = seniors[, 1:3],
y = seniors$count,
control = ctrl
)
expect_error(predict(res_xy, newdata = seniors[1:3, 1:3]), "Did you mean")
})

0 comments on commit d42c32c

Please sign in to comment.