Skip to content

Commit

Permalink
glmnetUtils namespace, and add associated test
Browse files Browse the repository at this point in the history
  • Loading branch information
krisrs1128 committed Aug 14, 2024
1 parent 6ccf5da commit 112f80e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ glmnet_model <- function(...) {
estimates = NULL,
sampler = glmnet_sampler,
predictor = \(object, ...) {
glmnetUtils::predict(object, s = object$lambda.1se, ...)[, 1]
predict(object, s = object$lambda.1se, ...)[, 1]
},
model_type = "glmnet_model()"
)
Expand Down
25 changes: 25 additions & 0 deletions tests/testthat/test-estimators.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

library(multimedia)

# Joy dataset example
exper <- mediation_data(demo_joy(), "PHQ", "treatment", starts_with("ASV"))

test_that("Estimation with RF works in a simple case.", {
model <- multimedia(exper) |>
estimate(exper)

predictions <- predict(model)
expect_equal(names(predictions), c("mediators", "outcomes"))
expect_equal(colnames(predictions$mediators), mediators(model))
expect_equal(mediators(model), paste0("ASV", 1:5))
})

test_that("Estimation with glmnet works in a simple case", {
model <- multimedia(exper, glmnet_model(lambda = .1)) |>
estimate(exper)

predictions <- predict(model)
expect_equal(names(predictions), c("mediators", "outcomes"))
expect_equal(colnames(predictions$mediators), mediators(model))
expect_equal(mediators(model), paste0("ASV", 1:5))
})

0 comments on commit 112f80e

Please sign in to comment.