Skip to content

Commit

Permalink
make progress bar optional
Browse files Browse the repository at this point in the history
  • Loading branch information
krisrs1128 committed Sep 3, 2024
1 parent 435fb7b commit 40d3a04
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
6 changes: 4 additions & 2 deletions R/contrast.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ indirect_overall <- function(model, exper = NULL, t1 = 1, t2 = 2) {
#' (pathwise) indirect effect.
#' @param t2 The alternative level of the treatment to be used when computing
#' the (pathwise) indirect effect.
#' @param progress A logical indicating whether to show a progress bar.
#' @return A data.frame summarizing the pathwise (per-mediator) indirect effects
#' associated with different settings of the direct effect.
#' @examples
Expand All @@ -279,7 +280,8 @@ indirect_overall <- function(model, exper = NULL, t1 = 1, t2 = 2) {
#' indirect_pathwise(fit)
#' @importFrom cli cli_text
#' @export
indirect_pathwise <- function(model, exper = NULL, t1 = 1, t2 = 2) {
indirect_pathwise <- function(
model, exper = NULL, t1 = 1, t2 = 2, progress = TRUE) {
pretreatment <- NULL
if (!is.null(exper)) {
pretreatment <- exper@pretreatments
Expand Down Expand Up @@ -314,7 +316,7 @@ indirect_pathwise <- function(model, exper = NULL, t1 = 1, t2 = 2) {
direct_setting = t_[[1]][i],
row.names = NULL
)
pb$tick()
if (progress) pb$tick()
k <- k + 1
}
}
Expand Down
25 changes: 18 additions & 7 deletions R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ sub_formula <- function(formula, yj) {
#' @param f A function for estimating a single response model given a formula
#' and input dataset. This is the model that we would like to parallelize
#' across responses.
#' @param progress A logical indicating whether to show a progress bar.
#' @return f_multi A function that takes a formula and dataset and applies f to
#' each response on the left hand side of the original formula.
#' @importFrom formula.tools lhs.vars
Expand All @@ -66,7 +67,7 @@ sub_formula <- function(formula, yj) {
#' prf <- parallelize(ranger::ranger)
#' prf(mpg + hp ~ wt + disp + cyl, data = mtcars)
#' @export
parallelize <- function(f) {
parallelize <- function(f, progress = TRUE) {
function(formula, ...) {
models <- list()

Expand All @@ -78,7 +79,7 @@ parallelize <- function(f) {
for (j in seq_along(ys)) {
fmla <- sub_formula(formula, ys[j])
models[[ys[j]]] <- f(fmla, ...)
pb$tick()
if (progress) pb$tick()
}
models
}
Expand Down Expand Up @@ -239,6 +240,8 @@ estimate <- function(model, exper) {
#' pretreatment variables, since each input is low-dimensional, even when there
#' are many responses.
#'
#' @param progress A logical indicating whether to show a progress bar during
#' estimation.
#' @return model An object of class `model` with estimator, predictor, and
#' sampler functions associated wtih a linear model.
#' @seealso model
Expand All @@ -247,7 +250,7 @@ estimate <- function(model, exper) {
#' m <- lm_model()
#' estimator(m)(mpg ~ hp + wt, data = mtcars)
#' @export
lm_model <- function() {
lm_model <- function(progress = TRUE) {
new(
"model",
estimator = parallelize(lm),
Expand Down Expand Up @@ -302,6 +305,8 @@ glmnet_model_params <- function(...) {
#' have many mediators or pretreatment variables, making the input
#' high-dimensional.
#'
#' @param progress A logical indicating whether to show a progress bar during
#' estimation.
#' @param ... Keyword parameters passed to glmnet.
#' @return model An object of class `model` with estimator, predictor, and
#' sampler functions associated wtih a lienar model.
Expand All @@ -322,7 +327,7 @@ glmnet_model_params <- function(...) {
#' multimedia(exper, glmnet_model(lambda = 0.1)) |>
#' estimate(exper)
#' @export
glmnet_model <- function(...) {
glmnet_model <- function(progress, ...) {
check_if_installed(
c("glmnet", "glmnetUtils"),
"to use a glmnet regression model for multimedia estimation."
Expand All @@ -333,7 +338,8 @@ glmnet_model <- function(...) {
new(
"model",
estimator = parallelize(
\(fmla, data) inject(glmnetUtils::glmnet(fmla, data, !!!params))
\(fmla, data) inject(glmnetUtils::glmnet(fmla, data, !!!params)),
progress = progress
),
estimates = NULL,
sampler = glmnet_sampler,
Expand Down Expand Up @@ -626,6 +632,8 @@ lnm_sampler <- function(fit, newdata = NULL, indices = NULL, ...) {
#' Internally, each of the models across the response are estimated using
#' ranger.
#'
#' @param progress A logical indicating whether to show a progress bar during
#' estimation.
#' @param ... Keyword parameters passed to ranger.
#' @return model An object of class `model` with estimator, predictor, and
#' sampler functions associated wtih a lienar model.
Expand All @@ -642,7 +650,7 @@ lnm_sampler <- function(fit, newdata = NULL, indices = NULL, ...) {
#' estimate(exper)
#' @seealso model lm_model rf_model glmnet_model brms_model
#' @export
rf_model <- function(...) {
rf_model <- function(progress = TRUE, ...) {
check_if_installed(
"ranger",
"to use a random forest model for multimedia estimation."
Expand All @@ -651,7 +659,10 @@ rf_model <- function(...) {

new(
"model",
estimator = parallelize(\(fmla, data) ranger::ranger(fmla, data, ...)),
estimator = parallelize(
\(fmla, data) ranger::ranger(fmla, data, ...),
progress
),
estimates = NULL,
sampler = rf_sampler,
model_type = "rf_model()",
Expand Down
5 changes: 3 additions & 2 deletions R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ nullify <- function(multimedia, nulls = NULL) {
#' @param exper An object of class multimedia_data containing the mediation and
#' outcome data from which the direct effects are to be estimated.
#' @param B The number of bootstrap samples. Defaults to 1000.
#' @param progress A logical indicating whether to show a progress bar.
#' @return stats A list of length B containing the results of the fs applied on
#' each of the B bootstrap resamples.
#' @importFrom progress progress_bar
Expand All @@ -139,7 +140,7 @@ nullify <- function(multimedia, nulls = NULL) {
#' ) +
#' ggplot2::facet_wrap(~outcome, scales = "free")
#' @export
bootstrap <- function(model, exper, fs = NULL, B = 1000) {
bootstrap <- function(model, exper, fs = NULL, B = 1000, progress = TRUE) {
if (is.null(fs)) {
fs <- list(direct_effect = direct_effect)
}
Expand All @@ -164,7 +165,7 @@ bootstrap <- function(model, exper, fs = NULL, B = 1000) {
# estimate model and effects
model_b <- estimate(model, exper_b)
stats[[nf]][[b]] <- fs[[f]](model_b, exper_b)
pb$tick()
if (progress) pb$tick()
}
stats[[nf]] <- bind_rows(stats[[nf]], .id = "bootstrap") |>
mutate(bootstrap = as.integer(bootstrap))
Expand Down

0 comments on commit 40d3a04

Please sign in to comment.