diff --git a/R/estimators.R b/R/estimators.R index 3766d78..ffca666 100644 --- a/R/estimators.R +++ b/R/estimators.R @@ -319,7 +319,7 @@ glmnet_model <- function(...) { estimates = NULL, sampler = glmnet_sampler, predictor = \(object, ...) { - glmnetUtils::predict(object, s = object$lambda.1se, ...)[, 1] + predict(object, s = object$lambda.1se, ...)[, 1] }, model_type = "glmnet_model()" ) diff --git a/tests/testthat/test-estimators.R b/tests/testthat/test-estimators.R new file mode 100644 index 0000000..6c28b8f --- /dev/null +++ b/tests/testthat/test-estimators.R @@ -0,0 +1,25 @@ + +library(multimedia) + +# Joy dataset example +exper <- mediation_data(demo_joy(), "PHQ", "treatment", starts_with("ASV")) + +test_that("Estimation with RF works in a simple case.", { + model <- multimedia(exper) |> + estimate(exper) + + predictions <- predict(model) + expect_equal(names(predictions), c("mediators", "outcomes")) + expect_equal(colnames(predictions$mediators), mediators(model)) + expect_equal(mediators(model), paste0("ASV", 1:5)) +}) + +test_that("Estimation with glmnet works in a simple case", { + model <- multimedia(exper, glmnet_model(lambda = .1)) |> + estimate(exper) + + predictions <- predict(model) + expect_equal(names(predictions), c("mediators", "outcomes")) + expect_equal(colnames(predictions$mediators), mediators(model)) + expect_equal(mediators(model), paste0("ASV", 1:5)) +})